fix for bf16 (#22)

arampacha 2021-07-06 23:44:48 +03:00 committed by GitHub
parent 9615befa43
commit 420ab78f56
4 changed files with 362 additions and 94 deletions

View File

@@ -2,43 +2,48 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "84b1a438-cf1d-402e-a56f-2c4f9dd5ad51",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, FlaxGPTNeoForCausalLM"
"from transformers import AutoTokenizer, FlaxGPTNeoForCausalLM, AutoModelForMaskedLM"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "7d50cd18-33ed-4b67-82ad-5c48eb9a9b36",
"metadata": {},
"outputs": [],
"source": [
"model_ckpt = 'EleutherAI/gpt-neo-125M'"
"from pathlib import Path\n",
"# model_ckpt = 'EleutherAI/gpt-neo-125M'\n",
"model_ckpt = (Path.home()/'gpt-neo-125M-code-clippy').as_posix()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"id": "2ec0c4cc-a1bc-4dda-bd0b-72891b519b39",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
"id": "065c03c3-2e4a-4f20-a30d-25ada1418b18",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57ed89aae89749aba08b715ec2258b82",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/501M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:absl:Starting the local TPU driver.\n",
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
]
}
],
"source": [
@@ -48,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 11,
"id": "e2f9fb26-2e26-4f57-aa93-e349475203f3",
"metadata": {},
"outputs": [],
@@ -58,22 +63,23 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 27,
"id": "75c0c2f6-47ad-41c3-8c66-a1ceeecde061",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"class Model(nn.Module):\n",
"x = np.random.randn(10, 10)\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 28,
"id": "666977a1-de0d-4900-bf61-ae2b672e51bc",
"metadata": {},
"outputs": [],
@@ -84,17 +90,17 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 29,
"id": "249e4a8a-7a7e-4e8b-83be-7184a4c0dd0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 19, 50257)"
"(1, 40, 50257)"
]
},
"execution_count": 43,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
@@ -106,19 +112,21 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 30,
"id": "eee873f5-073c-4cbe-8b15-114ea18b2de8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[ 198, 11748, 28034, 198, 6738, 28034, 1330, 299,\n",
" 77, 198, 198, 4871, 9104, 7, 20471, 13,\n",
" 26796, 2599, 198]], dtype=int32)"
"DeviceArray([[ 198, 11748, 299, 32152, 355, 45941, 198, 11748,\n",
" 19798, 292, 355, 279, 67, 198, 11748, 2603,\n",
" 29487, 8019, 13, 9078, 29487, 355, 458, 83,\n",
" 198, 198, 87, 796, 45941, 13, 25120, 13,\n",
" 25192, 77, 7, 940, 11, 838, 8, 198]], dtype=int32)"
]
},
"execution_count": 37,
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
@@ -129,55 +137,55 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 37,
"id": "82666225-3ab7-405f-9536-4e9e3085be24",
"metadata": {},
"outputs": [],
"source": [
"out = model.generate(input_ids,\n",
" max_length=200, \n",
"# num_beams=5,\n",
" num_beams=1,\n",
" pad_token_id = tokenizer.pad_token_id\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 38,
"id": "c6cc862b-23ef-417d-ae83-1b2eafb0460f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FlaxGreedySearchOutput(sequences=DeviceArray([[ 198, 11748, 28034, 198, 6738, 28034, 1330, 299,\n",
" 77, 198, 198, 4871, 9104, 7, 20471, 13,\n",
" 26796, 2599, 198, 220, 220, 220, 825, 11593,\n",
" 15003, 834, 7, 944, 11, 1438, 11, 2746,\n",
" 11, 12429, 46265, 22046, 2599, 198, 220, 220,\n",
" 220, 220, 220, 220, 220, 2208, 7, 17633,\n",
" 11, 2116, 737, 834, 15003, 834, 7, 3672,\n",
" 11, 2746, 11, 12429, 46265, 22046, 8, 198,\n",
" 220, 220, 220, 220, 220, 220, 220, 2116,\n",
" 13, 3672, 796, 1438, 198, 220, 220, 220,\n",
" 220, 220, 220, 220, 2116, 13, 19849, 796,\n",
" 2746, 198, 220, 220, 220, 220, 220, 220,\n",
" 220, 2116, 13, 46265, 22046, 796, 479, 86,\n",
" 22046, 198, 220, 220, 220, 220, 220, 220,\n",
" 220, 2116, 13, 3672, 62, 40290, 796, 705,\n",
" 19849, 6, 198, 220, 220, 220, 220, 220,\n",
" 220, 220, 2116, 13, 3672, 62, 37333, 844,\n",
" 796, 705, 19849, 62, 3672, 6, 198, 220,\n",
" 220, 220, 220, 220, 220, 220, 2116, 13,\n",
" 3672, 62, 40290, 62, 40290, 796, 705, 19849,\n",
" 62, 3672, 62, 40290, 6, 198, 220, 220,\n",
" 220, 220, 220, 220, 220, 2116, 13, 3672,\n",
" 62, 37333, 844, 62, 40290, 796, 705, 19849,\n",
" 62, 3672, 62, 37333, 844, 6, 198, 220,\n",
" 220, 220, 220, 220, 220, 220, 2116, 13]], dtype=int32))"
"FlaxGreedySearchOutput(sequences=DeviceArray([[ 198, 11748, 299, 32152, 355, 45941, 198, 11748,\n",
" 19798, 292, 355, 279, 67, 198, 11748, 2603,\n",
" 29487, 8019, 13, 9078, 29487, 355, 458, 83,\n",
" 198, 198, 87, 796, 45941, 13, 25120, 13,\n",
" 25192, 77, 7, 940, 11, 838, 8, 198,\n",
" 88, 796, 45941, 13, 25120, 13, 25192, 77,\n",
" 7, 940, 11, 838, 8, 198, 198, 2,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
" 220, 220, 220, 220, 220, 220, 220, 220]], dtype=int32))"
]
},
"execution_count": 50,
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
@@ -188,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 39,
"id": "8f6c746a-2d56-4da4-acb5-e066a6a230f2",
"metadata": {},
"outputs": [
@@ -197,26 +205,262 @@
"output_type": "stream",
"text": [
"\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self, name, model, **kwargs):\n",
" super(Model, self).__init__(name, model, **kwargs)\n",
" self.name = name\n",
" self.model = model\n",
" self.kwargs = kwargs\n",
" self.name_prefix ='model'\n",
" self.name_suffix ='model_name'\n",
" self.name_prefix_prefix ='model_name_prefix'\n",
" self.name_suffix_prefix ='model_name_suffix'\n",
" self.\n"
"x = np.random.randn(10, 10)\n",
"y = np.random.randn(10, 10)\n",
"\n",
"# \n"
]
}
],
"source": [
"print(tokenizer.decode(out[0][0]))"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "b6effeaa-2237-47bc-b0f6-940c4e274c38",
"metadata": {},
"outputs": [],
"source": [
"from transformers import GPTNeoForCausalLM, AutoModelForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "3665a3fd-5d92-45e8-8fde-393ec803383a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/arto/transformers/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)\n",
"All Flax model weights were used when initializing GPTNeoForCausalLM.\n",
"\n",
"Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['lm_head.weight', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.4.attn.attention.bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = GPTNeoForCausalLM.from_pretrained(model_ckpt, from_flax=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fa23301-6f5d-40d5-b614-f14330df894a",
"metadata": {},
"outputs": [],
"source": [
"from transormers import AutoModelForMaskedLM\n",
"model = AutoModelForMaskedLM.from_pretrained(model_ckpt, from_flax=True)\n",
"model.save_pretrained(model_ckpt, save_config=False)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "35114633-bb5f-4c00-ae16-540a7fabb126",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"All Flax model weights were used when initializing GPTNeoForCausalLM.\n",
"\n",
"Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['lm_head.weight', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.4.attn.attention.bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = AutoModelForCausalLM.from_pretrained(model_ckpt, from_flax=True)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "15cd3853-1308-46e1-90c1-52b3af0fcac4",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
"my_list = ['banana', 'apple', 'orange', 'pineapple']\n",
"\n",
"#Using brute force method\n",
"last_element = my_list[len(my_list) - 1]\n",
"\n",
"#Using negative indeces\n",
"last_element = my_list[-1]\n",
"\n",
"#Using pop method\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "2f2fc69a-f8f5-4859-bb2e-5c33e63f064a",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
"def get_vowels(string):\n",
" return [vowel for vowel in string if vowel in 'aeiou'] \n",
"\n",
"print(\"Vowels are:\",\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "517aa451-3316-45fc-97ab-1a9a52ba55b6",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"import time\n",
"\n",
"start_time = time.time()\n",
"\n",
"total = 0\n",
"for i in range(10):\n",
" total += i\n",
"print(\"Sum:\", total)\n",
"\n",
"end_time = time.time()\n",
"time_taken = \"\"\""
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "749f6c3d-e1a4-4df7-be81-086024345766",
"metadata": {},
"outputs": [],
"source": [
"inputs = tokenizer(prompt, return_tensors='pt')\n",
"input_ids = inputs.input_ids"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "6f60e3f0-d051-4df1-8258-7bc479486603",
"metadata": {},
"outputs": [],
"source": [
"out = model.generate(input_ids,\n",
" max_length=200, \n",
" num_beams=1,\n",
" pad_token_id = tokenizer.pad_token_id\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "9d17aec3-e42a-43d6-a535-2eeaad2a9c78",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import time\n",
"\n",
"start_time = time.time()\n",
"\n",
"total = 0\n",
"for i in range(10):\n",
" total += i\n",
"print(\"Sum:\", total)\n",
"\n",
"end_time = time.time()\n",
"time_taken = time.time()\n",
"\n",
"# \n"
]
}
],
"source": [
"print(tokenizer.decode(out[0]))"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "57574549-bd1d-46b0-98ca-352662f735d2",
"metadata": {},
"outputs": [],
"source": [
"model.save_pretrained(model_ckpt, save_config=False)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0578148d-497d-422f-b7fb-b644d2a7c62f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-06 15:02:08.590730: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2021-07-06 15:02:08.590769: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
}
],
"source": [
"from transformers import TrainingArguments"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "63ffd8ff-8e95-4aad-9068-b27fd8c129bb",
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import fields"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a9a89020-e7b0-4826-88e6-8ac4f4c6f89e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Field(name='skip_memory_metrics',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object at 0x7f9a12926af0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether or not to skip adding of memory profiler reports to metrics.'}),_field_type=_FIELD)\n"
]
}
],
"source": [
"for f in fields(TrainingArguments):\n",
" if f.name == \"skip_memory_metrics\":\n",
" print(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "beda87b7-c461-4f92-8988-4255a8e79cf9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -36,6 +36,7 @@ from tqdm import tqdm
import jax
import jax.numpy as jnp
import jax.profiler
import optax
import transformers
from flax import jax_utils, traverse_util
@@ -350,6 +351,14 @@ def main():
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
# Convert weights to bf16 manually
# TODO: remove when .from_pretrained handles it properly
def to_bf16(t):
return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
if model_args.dtype == "bfloat16":
model.params = to_bf16(model.params)
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
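For illustration, a minimal standalone sketch of the bf16 cast added above, applied to a toy parameter pytree (the parameter names and shapes are made up; as in the diff, only float32 leaves are cast):

import jax
import jax.numpy as jnp

def to_bf16(t):
    # Cast every float32 leaf of the pytree to bfloat16; leave other dtypes (e.g. int32) untouched.
    return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)

params = {
    "dense": {"kernel": jnp.ones((4, 4), dtype=jnp.float32), "bias": jnp.zeros((4,), dtype=jnp.float32)},
    "step": jnp.array(0, dtype=jnp.int32),
}
print(jax.tree_map(lambda x: x.dtype, to_bf16(params)))
# float32 leaves are now bfloat16; the int32 leaf keeps its dtype.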
@@ -492,18 +501,21 @@ def main():
}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# create optimizer
if training_args.adafactor:
optimizer = optax.adafactor(3e-4)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
def loss_fn(logits, labels):
shift_logits = logits[..., :-1, :]
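As a rough standalone sketch of the optimizer selection introduced above (the schedule and hyperparameter values are placeholders, not the script's configured ones):

import optax

# Stand-in for linear_decay_lr_schedule_fn.
schedule = optax.linear_schedule(init_value=3e-4, end_value=0.0, transition_steps=10_000)

use_adafactor = False  # corresponds to training_args.adafactor
if use_adafactor:
    optimizer = optax.adafactor(3e-4)
else:
    optimizer = optax.adamw(
        learning_rate=schedule,
        b1=0.9,
        b2=0.98,
        eps=1e-8,
        weight_decay=0.01,
    )
# Either optimizer is then passed to TrainState.create(..., tx=optimizer, ...) as in the diff.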
@@ -557,6 +569,9 @@ def main():
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
if not training_args.skip_memory_metrics:
server = jax.profiler.start_server(9999)
train_time = 0
train_metrics = []
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)

View File

@@ -6,15 +6,20 @@
--dataset_config_name="python" \
--text_column_name="func_code_string" \
--do_train --do_eval \
--block_size="128" \
--per_device_train_batch_size="64" \
--per_device_eval_batch_size="128" \
--block_size="1024" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--preprocessing_num_workers="8" \
--learning_rate="5e-3" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--weight_decay="0.01" \
--overwrite_output_dir \
--num_train_epochs="1" \
--push_to_hub="False"
--push_to_hub="False" \
--dtype="float32" \
--skip_memory_metrics="False"
# --max_train_samples="10000" \
# --max_eval_samples="1000" \

View File

@@ -6,16 +6,20 @@
--dataset_config_name="python" \
--text_column_name="func_code_string" \
--do_train --do_eval \
--block_size="128" \
--per_device_train_batch_size="1" \
--block_size="1024" \
--per_device_train_batch_size="2" \
--per_device_eval_batch_size="2" \
--preprocessing_num_workers="8" \
--dtype="bfloat16" \
--learning_rate="5e-3" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--weight_decay="0.01" \
--overwrite_output_dir \
--num_train_epochs="1" \
--push_to_hub="False"
--push_to_hub="False" \
--dtype="bfloat16" \
--adafactor="False" \
--skip_memory_metrics="False"
# --max_train_samples="10000" \
# --max_eval_samples="1000" \
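For context, a minimal sketch of how flags like --dtype, --adafactor and --skip_memory_metrics in these scripts reach the training code (field names are taken from the diffs above; the exact dataclass layout in the script is an assumption):

from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class ModelArguments:
    # "float32" or "bfloat16"; read as model_args.dtype and used for the bf16 cast above.
    dtype: str = field(default="float32")

parser = HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()
# training_args.adafactor selects optax.adafactor over adamw,
# and training_args.skip_memory_metrics gates the jax.profiler server.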