gpt-code-clippy/flax-gpt-neo-clm-v3.ipynb
2021-07-07 16:21:24 +00:00

838 lines
36 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "6d0d0fdf-def5-4c93-b1e9-4223d70c3c22",
"metadata": {},
"outputs": [],
"source": [
"#!/usr/bin/env python\n",
"# coding=utf-8\n",
"# Copyright 2021 The HuggingFace Team All rights reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\"\"\"\n",
"Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.\n",
"\n",
"Here is the full list of checkpoints on the hub that can be fine-tuned by this script:\n",
"https://huggingface.co/models?filter=causal-lm\n",
"\"\"\"\n",
"# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.\n",
"\n",
"import logging\n",
"import math\n",
"import os\n",
"import sys\n",
"import time\n",
"from dataclasses import dataclass, field\n",
"from pathlib import Path\n",
"from typing import Callable, Optional\n",
"\n",
"import datasets\n",
"from datasets import Dataset, load_dataset\n",
"from tqdm.auto import tqdm\n",
"\n",
"import json\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import transformers\n",
"from flax import jax_utils, traverse_util\n",
"from flax.jax_utils import unreplicate\n",
"from flax.training import train_state\n",
"from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key\n",
"from flax.serialization import to_bytes, from_bytes\n",
"from transformers import (\n",
" CONFIG_MAPPING,\n",
" FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n",
" AutoConfig,\n",
" AutoTokenizer,\n",
" FlaxAutoModelForCausalLM,\n",
" HfArgumentParser,\n",
" TrainingArguments,\n",
" is_tensorboard_available,\n",
")\n",
"from transformers.testing_utils import CaptureLogger\n",
"\n",
"\n",
"# Logger used for this notebook's own messages.\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"# Config classes registered for Flax causal-LM models, and their `model_type` strings.\n",
"MODEL_CONFIG_CLASSES = [config_cls for config_cls in FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()]\n",
"MODEL_TYPES = tuple(map(lambda config_cls: config_cls.model_type, MODEL_CONFIG_CLASSES))\n",
"\n",
"\n",
"@dataclass\n",
"class ModelArguments:\n",
"    \"\"\"\n",
"    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n",
"\n",
"    Every field has a default; the `help` metadata of each field describes it.\n",
"    \"\"\"\n",
"\n",
"    model_name_or_path: Optional[str] = field(\n",
"        default=None,\n",
"        metadata={\n",
"            # The trailing space keeps the two implicitly concatenated sentences from running together.\n",
"            \"help\": \"The model checkpoint for weights initialization. \"\n",
"            \"Don't set if you want to train a model from scratch.\"\n",
"        },\n",
"    )\n",
"    model_type: Optional[str] = field(\n",
"        default=None,\n",
"        metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n",
"    )\n",
"    config_name: Optional[str] = field(\n",
"        default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n",
"    )\n",
"    tokenizer_name: Optional[str] = field(\n",
"        default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n",
"    )\n",
"    cache_dir: Optional[str] = field(\n",
"        default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n",
"    )\n",
"    use_fast_tokenizer: bool = field(\n",
"        default=True,\n",
"        metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n",
"    )\n",
"    # Name of a `jax.numpy` dtype; it is resolved with `getattr(jnp, ...)` when the model is loaded.\n",
"    dtype: Optional[str] = field(\n",
"        default=\"float32\",\n",
"        metadata={\n",
"            \"help\": \"Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`.\"\n",
"        },\n",
"    )\n",
"\n",
"\n",
"@dataclass\n",
"class DataTrainingArguments:\n",
"    \"\"\"\n",
"    Arguments pertaining to what data we are going to input our model for training and eval.\n",
"\n",
"    Either `dataset_name` or at least one of `train_file` / `validation_file` must be given;\n",
"    `__post_init__` enforces this and checks the file extensions.\n",
"    \"\"\"\n",
"\n",
"    dataset_name: Optional[str] = field(\n",
"        default=None, metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n",
"    )\n",
"    dataset_config_name: Optional[str] = field(\n",
"        default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n",
"    )\n",
"    train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n",
"    validation_file: Optional[str] = field(\n",
"        default=None,\n",
"        metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n",
"    )\n",
"    max_train_samples: Optional[int] = field(\n",
"        default=None,\n",
"        metadata={\n",
"            \"help\": \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n",
"            \"value if set.\"\n",
"        },\n",
"    )\n",
"    max_eval_samples: Optional[int] = field(\n",
"        default=None,\n",
"        metadata={\n",
"            \"help\": \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n",
"            \"value if set.\"\n",
"        },\n",
"    )\n",
"    # Declared once only; a second identical declaration further down was redundant.\n",
"    overwrite_cache: bool = field(\n",
"        default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n",
"    )\n",
"    validation_split_percentage: Optional[int] = field(\n",
"        default=5,\n",
"        metadata={\n",
"            \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n",
"        },\n",
"    )\n",
"    block_size: Optional[int] = field(\n",
"        default=None,\n",
"        metadata={\n",
"            \"help\": \"Optional input sequence length after tokenization. \"\n",
"            \"The training dataset will be truncated in block of this size for training. \"\n",
"            \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n",
"        },\n",
"    )\n",
"    preprocessing_num_workers: Optional[int] = field(\n",
"        default=None,\n",
"        metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n",
"    )\n",
"    text_column_name: Optional[str] = field(\n",
"        default='text',\n",
"        metadata={\"help\": \"Column containing main text data.\"},\n",
"    )\n",
"\n",
"    def __post_init__(self):\n",
"        \"\"\"Check that a data source was supplied and that any given file has a supported extension.\n",
"\n",
"        Raises:\n",
"            ValueError: if neither a dataset name nor a train/validation file is set.\n",
"            AssertionError: if a given file is not a csv, json or txt file.\n",
"        \"\"\"\n",
"        if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n",
"            raise ValueError(\"Need either a dataset name or a training/validation file.\")\n",
"        if self.train_file is not None:\n",
"            extension = self.train_file.split(\".\")[-1]\n",
"            assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n",
"        if self.validation_file is not None:\n",
"            extension = self.validation_file.split(\".\")[-1]\n",
"            assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n",
"\n",
"\n",
"class TrainState(train_state.TrainState):\n",
"    \"\"\"Flax train state that additionally carries the PRNG key used for dropout.\"\"\"\n",
"\n",
"    dropout_rng: jnp.ndarray\n",
"\n",
"    def replicate(self):\n",
"        \"\"\"Return a replicated copy of this state whose dropout key is sharded with `shard_prng_key`.\"\"\"\n",
"        replicated_state = jax_utils.replicate(self)\n",
"        per_device_rng = shard_prng_key(self.dropout_rng)\n",
"        return replicated_state.replace(dropout_rng=per_device_rng)\n",
"\n",
"\n",
"def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
"    \"\"\"\n",
"    Yield batches of `batch_size` rows from `dataset`, each passed through `shard`.\n",
"\n",
"    Rows are visited in a random order drawn from `rng` when `shuffle` is `True`, otherwise in\n",
"    index order. Trailing rows that do not fill a whole batch are dropped.\n",
"    \"\"\"\n",
"    num_rows = len(dataset)\n",
"    num_batches = num_rows // batch_size\n",
"\n",
"    row_order = jax.random.permutation(rng, num_rows) if shuffle else jnp.arange(num_rows)\n",
"\n",
"    # Keep only as many indices as fill complete batches, then lay them out one batch per row.\n",
"    batched_rows = row_order[: num_batches * batch_size].reshape((num_batches, batch_size))\n",
"\n",
"    for row_ids in batched_rows:\n",
"        columns = dataset[row_ids]\n",
"        yield shard({name: jnp.array(values) for name, values in columns.items()})\n",
"\n",
"\n",
"def write_train_metric(summary_writer, train_metrics, train_time, step):\n",
"    \"\"\"Write `train_time` at `step`, then each collected training metric value at its own step.\n",
"\n",
"    The values gathered since the last call are assigned consecutive steps ending at `step`.\n",
"    \"\"\"\n",
"    summary_writer.scalar(\"train_time\", train_time, step)\n",
"\n",
"    stacked_metrics = get_metrics(train_metrics)\n",
"    for key, vals in stacked_metrics.items():\n",
"        tag = f\"train_{key}\"\n",
"        first_step = step - len(vals) + 1\n",
"        for offset, val in enumerate(vals):\n",
"            summary_writer.scalar(tag, val, first_step + offset)\n",
"\n",
"\n",
"def write_eval_metric(summary_writer, eval_metrics, step):\n",
"    \"\"\"Write every entry of `eval_metrics` at `step`, tagged `eval_<name>`.\"\"\"\n",
"    for name in eval_metrics:\n",
"        summary_writer.scalar(f\"eval_{name}\", eval_metrics[name], step)\n",
"\n",
"\n",
"def create_learning_rate_fn(\n",
"    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n",
") -> Callable[[int], jnp.ndarray]:\n",
"    \"\"\"Returns a linear warmup, linear_decay learning rate function.\n",
"\n",
"    The rate rises linearly from 0 to `learning_rate` over `num_warmup_steps`, then falls linearly\n",
"    to 0 over the remaining training steps. If the warmup is at least as long as the whole run\n",
"    there is nothing left to decay over; the rate then stays at `learning_rate` after warmup and\n",
"    a warning is logged.\n",
"\n",
"    Args:\n",
"        train_ds_size: number of training examples.\n",
"        train_batch_size: total batch size per optimization step.\n",
"        num_train_epochs: number of passes over the training set.\n",
"        num_warmup_steps: number of warmup steps.\n",
"        learning_rate: peak learning rate reached at the end of warmup.\n",
"\n",
"    Returns:\n",
"        A schedule mapping a step count to a learning rate.\n",
"    \"\"\"\n",
"    steps_per_epoch = train_ds_size // train_batch_size\n",
"    num_train_steps = steps_per_epoch * num_train_epochs\n",
"    num_decay_steps = num_train_steps - num_warmup_steps\n",
"\n",
"    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n",
"    if num_decay_steps > 0:\n",
"        decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_decay_steps)\n",
"    else:\n",
"        # A linear schedule with non-positive `transition_steps` silently degrades to a constant;\n",
"        # build that constant explicitly and tell the user why no decay happens.\n",
"        logger.warning(\n",
"            f\"num_warmup_steps ({num_warmup_steps}) is not smaller than the number of training steps \"\n",
"            f\"({num_train_steps}); the learning rate will not decay after warmup.\"\n",
"        )\n",
"        decay_fn = optax.constant_schedule(learning_rate)\n",
"\n",
"    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b87fb52e-7e4b-4a69-8c63-fe739331c1c5",
"metadata": {},
"outputs": [],
"source": [
"# Start from the GPT-Neo 125M checkpoint, using the bfloat16 dtype.\n",
"model_args = ModelArguments(dtype=\"bfloat16\", model_name_or_path=\"EleutherAI/gpt-neo-125M\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75c30ff3-94ec-437f-af79-17ad8429d7eb",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Python subset of CodeSearchNet, capped to small train/eval sets for a quick run.\n",
"data_args = DataTrainingArguments(\n",
"    dataset_name=\"code_search_net\",\n",
"    dataset_config_name=\"python\",\n",
"    # NOTE(review): code_search_net has no \"text\" column, so the default would fall back to the\n",
"    # first column (presumably `repository_name`); train on the function source instead.\n",
"    # Confirm the column name against the dataset card.\n",
"    text_column_name=\"func_code_string\",\n",
"    block_size=1024,\n",
"    max_train_samples=10000,\n",
"    max_eval_samples=1000,\n",
"    preprocessing_num_workers=8,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e618fc54-07d1-4e6c-bb48-7e050044a0c8",
"metadata": {},
"outputs": [],
"source": [
"# Per-device batch size, shared by training and evaluation.\n",
"bs = 16\n",
"training_args = TrainingArguments(\n",
"    # What to run and where results go.\n",
"    output_dir=\"./tmp\",\n",
"    overwrite_output_dir=True,\n",
"    do_train=True,\n",
"    do_eval=True,\n",
"    push_to_hub=False,\n",
"    report_to=None,\n",
"    # Optimization settings.\n",
"    num_train_epochs=1,\n",
"    per_device_train_batch_size=bs,\n",
"    per_device_eval_batch_size=bs,\n",
"    learning_rate=3e-4,\n",
"    weight_decay=0.1,\n",
"    warmup_steps=100,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20645921-b444-4e8d-b582-08af685b42d0",
"metadata": {},
"outputs": [],
"source": [
"def save_checkpoint(model, save_dir, state):\n",
"    \"\"\"Write the model weights, optimizer state and step counter of `state` into `save_dir`.\n",
"\n",
"    Args:\n",
"        model: model whose `save_pretrained` is used to write the weights.\n",
"        save_dir: directory that receives the model files, `opt_state.msgpack` and\n",
"            `training_state.json`.\n",
"        state: replicated train state; it is unreplicated before saving.\n",
"\n",
"    Pushing to the hub follows the notebook-level `training_args.push_to_hub`.\n",
"    \"\"\"\n",
"    state = jax_utils.unreplicate(state)\n",
"    step = int(state.step)\n",
"    print(f\"SAVING CHECKPOINT IN {save_dir}\", end=\" ... \")\n",
"    # Save into `save_dir` (not a global output directory) and describe the checkpoint by its\n",
"    # step, so the function does not depend on a loop variable existing in the notebook.\n",
"    model.save_pretrained(\n",
"        save_dir,\n",
"        params=state.params,\n",
"        push_to_hub=training_args.push_to_hub,\n",
"        commit_message=f\"Saving weights and logs of step {step}\",\n",
"    )\n",
"    with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"wb\") as f:\n",
"        f.write(to_bytes(state.opt_state))\n",
"    with open(os.path.join(save_dir, \"training_state.json\"), \"w\") as f:\n",
"        json.dump({\"step\": step}, f)\n",
"    print(\"DONE\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d11fd82-435d-4afd-b1eb-9914cd465a18",
"metadata": {},
"outputs": [],
"source": [
"def restore_checkpoint(save_dir, state):\n",
"    \"\"\"Read weights, optimizer state and step counter back from `save_dir`.\n",
"\n",
"    `state.params` and `state.opt_state` serve as the target structures for deserialization.\n",
"\n",
"    Returns:\n",
"        A `(params, opt_state, step)` tuple.\n",
"    \"\"\"\n",
"    print(f\"RESTORING CHECKPOINT FROM {save_dir}\", end=\" ... \")\n",
"\n",
"    params_path = os.path.join(save_dir, \"flax_model.msgpack\")\n",
"    with open(params_path, \"rb\") as params_file:\n",
"        params = from_bytes(state.params, params_file.read())\n",
"\n",
"    opt_state_path = os.path.join(save_dir, \"opt_state.msgpack\")\n",
"    with open(opt_state_path, \"rb\") as opt_state_file:\n",
"        opt_state = from_bytes(state.opt_state, opt_state_file.read())\n",
"\n",
"    training_state_path = os.path.join(save_dir, \"training_state.json\")\n",
"    with open(training_state_path, \"r\") as training_state_file:\n",
"        step = json.load(training_state_file)[\"step\"]\n",
"\n",
"    print(\"DONE\")\n",
"    return params, opt_state, step"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "562e0fd9-83b5-4d51-9232-df97fca4f063",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:absl:A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.\n"
]
}
],
"source": [
"# Learning-rate schedule built from a hard-coded dataset size and total batch size.\n",
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
"    train_ds_size=10000,\n",
"    train_batch_size=128,\n",
"    num_train_epochs=training_args.num_train_epochs,\n",
"    num_warmup_steps=training_args.warmup_steps,\n",
"    learning_rate=training_args.learning_rate,\n",
")\n",
"\n",
"\n",
"def decay_mask_fn(params):\n",
"    \"\"\"Return a pytree shaped like `params` that is True where weight decay should apply.\n",
"\n",
"    Biases and the `ln_1` / `ln_2` / `ln_f` scale parameters are excluded from decay.\n",
"    \"\"\"\n",
"    layer_norm_scales = [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")]\n",
"    flat_mask = {}\n",
"    for path in traverse_util.flatten_dict(params):\n",
"        is_bias = path[-1] == \"bias\"\n",
"        flat_mask[path] = not (is_bias or path[-2:] in layer_norm_scales)\n",
"    return traverse_util.unflatten_dict(flat_mask)\n",
"\n",
"\n",
"optimizer = optax.adamw(\n",
"    learning_rate=linear_decay_lr_schedule_fn,\n",
"    weight_decay=training_args.weight_decay,\n",
"    mask=decay_mask_fn,\n",
"    b1=training_args.adam_beta1,\n",
"    b2=training_args.adam_beta2,\n",
"    eps=training_args.adam_epsilon,\n",
")\n",
"# Load the model previously saved in the output directory and wrap it in a fresh train state.\n",
"model = FlaxAutoModelForCausalLM.from_pretrained(training_args.output_dir)\n",
"rng = jax.random.PRNGKey(training_args.seed)\n",
"rng, dropout_rng = jax.random.split(rng)\n",
"state = TrainState.create(\n",
"    apply_fn=model.__call__,\n",
"    params=model.params,\n",
"    tx=optimizer,\n",
"    dropout_rng=dropout_rng,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "56d2d8be-3516-4f86-b0e5-20b3a9079163",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RESTORING CHECKPOINT FROM ./tmp ... DONE\n"
]
}
],
"source": [
"# Read the weights, optimizer state and step counter saved in the output directory.\n",
"params, opt_state, step = restore_checkpoint(training_args.output_dir, state)\n",
"# Put the restored values into the train state; otherwise they are loaded but never used.\n",
"state = state.replace(step=step, params=params, opt_state=opt_state)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0db5eca6-f081-4c3e-a7f1-3e080da8cc89",
"metadata": {},
"outputs": [],
"source": [
"# Refuse to train into a non-empty output directory unless overwriting was requested.\n",
"if (\n",
"    os.path.exists(training_args.output_dir)\n",
"    and os.listdir(training_args.output_dir)\n",
"    and training_args.do_train\n",
"    and not training_args.overwrite_output_dir\n",
"):\n",
"    raise ValueError(\n",
"        # Trailing space added: the two sentences used to run together in the message.\n",
"        f\"Output directory ({training_args.output_dir}) already exists and is not empty. \"\n",
"        \"Use --overwrite_output_dir to overcome.\"\n",
"    )\n",
"\n",
"# Configure root logging the same way on every process (helps when debugging).\n",
"logging.basicConfig(\n",
"    level=logging.INFO,\n",
"    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
"    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
")\n",
"# Only process 0 logs at INFO level; every other process is limited to errors.\n",
"if jax.process_index() == 0:\n",
"    logger.setLevel(logging.INFO)\n",
"    datasets.utils.logging.set_verbosity_warning()\n",
"    transformers.utils.logging.set_verbosity_info()\n",
"else:\n",
"    logger.setLevel(logging.ERROR)\n",
"    datasets.utils.logging.set_verbosity_error()\n",
"    transformers.utils.logging.set_verbosity_error()\n",
"\n",
"# Record the full training configuration.\n",
"logger.info(f\"Training/evaluation parameters {training_args}\")\n",
"\n",
"# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n",
"# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n",
"# (the dataset will be downloaded automatically from the datasets Hub).\n",
"#\n",
"# The column to tokenize is `data_args.text_column_name`, or the first column if that one is absent.\n",
"#\n",
"# In distributed training, the load_dataset function guarantees that only one local process can concurrently\n",
"# download the dataset.\n",
"if data_args.dataset_name is not None:\n",
"    # Downloading and loading a dataset from the hub.\n",
"    dataset = load_dataset(\n",
"        data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False\n",
"    )\n",
"\n",
"    if \"validation\" not in dataset.keys():\n",
"        # No validation split available: use the first N% of train for validation and the rest for training.\n",
"        dataset[\"validation\"] = load_dataset(\n",
"            data_args.dataset_name,\n",
"            data_args.dataset_config_name,\n",
"            split=f\"train[:{data_args.validation_split_percentage}%]\",\n",
"            cache_dir=model_args.cache_dir,\n",
"        )\n",
"        dataset[\"train\"] = load_dataset(\n",
"            data_args.dataset_name,\n",
"            data_args.dataset_config_name,\n",
"            split=f\"train[{data_args.validation_split_percentage}%:]\",\n",
"            cache_dir=model_args.cache_dir,\n",
"        )\n",
"else:\n",
"    data_files = {}\n",
"    if data_args.train_file is not None:\n",
"        data_files[\"train\"] = data_args.train_file\n",
"    if data_args.validation_file is not None:\n",
"        data_files[\"validation\"] = data_args.validation_file\n",
"    # Take the extension from whichever file was supplied, so that passing only a validation\n",
"    # file does not fail on `None.split`.\n",
"    extension = (data_args.train_file or data_args.validation_file).split(\".\")[-1]\n",
"    if extension == \"txt\":\n",
"        extension = \"text\"\n",
"    dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)\n",
"# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n",
"# https://huggingface.co/docs/datasets/loading_datasets.html.\n",
"\n",
"# Load config, tokenizer and model.\n",
"#\n",
"# Distributed training:\n",
"# The .from_pretrained methods guarantee that only one local process can concurrently\n",
"# download model & vocab.\n",
"\n",
"# Config: an explicit config name wins over the model checkpoint; otherwise build one from `model_type`.\n",
"if model_args.config_name:\n",
"    config = AutoConfig.from_pretrained(\n",
"        model_args.config_name,\n",
"        cache_dir=model_args.cache_dir,\n",
"    )\n",
"elif model_args.model_name_or_path:\n",
"    config = AutoConfig.from_pretrained(\n",
"        model_args.model_name_or_path,\n",
"        cache_dir=model_args.cache_dir,\n",
"    )\n",
"else:\n",
"    config = CONFIG_MAPPING[model_args.model_type]()\n",
"    logger.warning(\"You are instantiating a new config instance from scratch.\")\n",
"\n",
"# Tokenizer: same precedence, but creating one from scratch is not supported here.\n",
"if model_args.tokenizer_name:\n",
"    tokenizer = AutoTokenizer.from_pretrained(\n",
"        model_args.tokenizer_name,\n",
"        cache_dir=model_args.cache_dir,\n",
"        use_fast=model_args.use_fast_tokenizer,\n",
"    )\n",
"elif model_args.model_name_or_path:\n",
"    tokenizer = AutoTokenizer.from_pretrained(\n",
"        model_args.model_name_or_path,\n",
"        cache_dir=model_args.cache_dir,\n",
"        use_fast=model_args.use_fast_tokenizer,\n",
"    )\n",
"else:\n",
"    raise ValueError(\n",
"        \"You are instantiating a new tokenizer from scratch. This is not supported by this script.\"\n",
"        \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n",
"    )\n",
"\n",
"# Model: pretrained weights when a checkpoint is named, otherwise initialized from the config.\n",
"if model_args.model_name_or_path:\n",
"    model = FlaxAutoModelForCausalLM.from_pretrained(\n",
"        model_args.model_name_or_path,\n",
"        config=config,\n",
"        seed=training_args.seed,\n",
"        dtype=getattr(jnp, model_args.dtype),\n",
"    )\n",
"else:\n",
"    model = FlaxAutoModelForCausalLM.from_config(\n",
"        config,\n",
"        seed=training_args.seed,\n",
"        dtype=getattr(jnp, model_args.dtype),\n",
"    )\n",
"\n",
"# Preprocessing the datasets.\n",
"# First pick the column to tokenize, looking at the train split when training, else validation.\n",
"column_names = dataset[\"train\" if training_args.do_train else \"validation\"].column_names\n",
"if data_args.text_column_name in column_names:\n",
"    text_column_name = data_args.text_column_name\n",
"else:\n",
"    # The requested column is absent: fall back to the first column.\n",
"    text_column_name = column_names[0]\n",
"\n",
"# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function\n",
"tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n",
"\n",
"def tokenize_function(examples):\n",
"    \"\"\"Tokenize the `text_column_name` column of a batch of examples.\n",
"\n",
"    The tokenizer's log output is captured so that its over-length warning can be followed\n",
"    by a note that long inputs are chunked later.\n",
"    \"\"\"\n",
"    with CaptureLogger(tok_logger) as captured:\n",
"        encoded = tokenizer(examples[text_column_name])\n",
"    # clm input could be much much longer than block_size\n",
"    length_warning_seen = \"Token indices sequence length is longer than the\" in captured.out\n",
"    if length_warning_seen:\n",
"        tok_logger.warning(\n",
"            \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model.\"\n",
"        )\n",
"    return encoded\n",
"\n",
"# Tokenize every split, dropping the original columns.\n",
"tokenized_datasets = dataset.map(\n",
"    tokenize_function,\n",
"    batched=True,\n",
"    remove_columns=column_names,\n",
"    num_proc=data_args.preprocessing_num_workers,\n",
"    load_from_cache_file=not data_args.overwrite_cache,\n",
")\n",
"\n",
"# Resolve the chunk length used when grouping tokens.\n",
"if data_args.block_size is not None:\n",
"    # An explicit block size is capped at the tokenizer's maximum length.\n",
"    if data_args.block_size > tokenizer.model_max_length:\n",
"        logger.warning(\n",
"            f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model\"\n",
"            f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n",
"        )\n",
"    block_size = min(data_args.block_size, tokenizer.model_max_length)\n",
"else:\n",
"    # No block size given: use the tokenizer's maximum, unless it exceeds the model's position range.\n",
"    block_size = tokenizer.model_max_length\n",
"    if block_size > config.max_position_embeddings:\n",
"        logger.warning(\n",
"            f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n",
"            \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n",
"        )\n",
"        block_size = 1024\n",
"\n",
"# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n",
"def group_texts(examples):\n",
"    \"\"\"Concatenate every column of a tokenized batch and cut it into `block_size`-long chunks.\n",
"\n",
"    Args:\n",
"        examples: mapping from column name to a list of token lists.\n",
"\n",
"    Returns:\n",
"        A dict with the same columns, each a list of chunks of exactly `block_size` items, plus\n",
"        a `labels` column copied from `input_ids`. Tokens left over after the last full chunk\n",
"        are dropped.\n",
"    \"\"\"\n",
"    # Local import so this function does not depend on a change to the notebook's import cell.\n",
"    from itertools import chain\n",
"\n",
"    # Concatenate all texts. `chain.from_iterable` flattens in linear time, whereas\n",
"    # `sum(lists, [])` re-copies the accumulated list for every element.\n",
"    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}\n",
"    total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
"    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
"    # customize this part to your needs.\n",
"    total_length = (total_length // block_size) * block_size\n",
"    # Split by chunks of max_len.\n",
"    result = {\n",
"        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
"        for k, t in concatenated_examples.items()\n",
"    }\n",
"    result[\"labels\"] = result[\"input_ids\"].copy()\n",
"    return result\n",
"\n",
"# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n",
"# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n",
"# to preprocess.\n",
"#\n",
"# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:\n",
"# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n",
"\n",
"lm_datasets = tokenized_datasets.map(\n",
"    group_texts,\n",
"    batched=True,\n",
"    num_proc=data_args.preprocessing_num_workers,\n",
"    load_from_cache_file=not data_args.overwrite_cache,\n",
")\n",
"\n",
"if training_args.do_train:\n",
"    if \"train\" not in tokenized_datasets:\n",
"        raise ValueError(\"--do_train requires a train dataset\")\n",
"    train_dataset = lm_datasets[\"train\"]\n",
"    if data_args.max_train_samples is not None:\n",
"        # Clamp the cap to the dataset length so that a small dataset does not make `select` fail.\n",
"        train_dataset = train_dataset.select(range(min(len(train_dataset), data_args.max_train_samples)))\n",
"\n",
"if training_args.do_eval:\n",
"    if \"validation\" not in tokenized_datasets:\n",
"        raise ValueError(\"--do_eval requires a validation dataset\")\n",
"    eval_dataset = lm_datasets[\"validation\"]\n",
"    if data_args.max_eval_samples is not None:\n",
"        # Same clamping as for the training set.\n",
"        eval_dataset = eval_dataset.select(range(min(len(eval_dataset), data_args.max_eval_samples)))\n",
"\n",
"# Create a TensorBoard writer on process 0 only, when TensorBoard is available.\n",
"has_tensorboard = is_tensorboard_available()\n",
"if not (has_tensorboard and jax.process_index() == 0):\n",
"    logger.warning(\n",
"        \"Unable to display metrics through TensorBoard because the package is not installed: \"\n",
"        \"Please run pip install tensorboard to enable.\"\n",
"    )\n",
"else:\n",
"    try:\n",
"        from flax.metrics.tensorboard import SummaryWriter\n",
"\n",
"        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n",
"    except ImportError as import_error:\n",
"        # Without the writer's dependencies, metrics are simply not written.\n",
"        has_tensorboard = False\n",
"        logger.warning(\n",
"            f\"Unable to display metrics through TensorBoard because some package are not installed: {import_error}\"\n",
"        )\n",
"\n",
"# Seed the PRNG and split off the key used for dropout.\n",
"rng = jax.random.PRNGKey(training_args.seed)\n",
"rng, dropout_rng = jax.random.split(rng)\n",
"\n",
"# Run-wide constants; batch sizes are per-device sizes times the device count.\n",
"num_epochs = int(training_args.num_train_epochs)\n",
"train_batch_size = jax.device_count() * int(training_args.per_device_train_batch_size)\n",
"eval_batch_size = jax.device_count() * int(training_args.per_device_eval_batch_size)\n",
"steps_per_epoch = len(train_dataset) // train_batch_size\n",
"total_train_steps = steps_per_epoch * num_epochs\n",
"\n",
"# Learning-rate schedule for this run's dataset size and total batch size.\n",
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
"    train_ds_size=len(train_dataset),\n",
"    train_batch_size=train_batch_size,\n",
"    num_train_epochs=training_args.num_train_epochs,\n",
"    num_warmup_steps=training_args.warmup_steps,\n",
"    learning_rate=training_args.learning_rate,\n",
")\n",
"\n",
"# Optax's `mask` argument lets us skip weight decay for selected parameters.\n",
"# decay_mask_fn returns a boolean pytree with the same structure as the parameters;\n",
"# True marks a parameter that should be decayed.\n",
"# The excluded names (biases and the ln_1 / ln_2 / ln_f scales) were chosen for FlaxGPT2;\n",
"# for other models, the layer norm parameter naming should be corrected accordingly.\n",
"def decay_mask_fn(params):\n",
"    \"\"\"Return the weight-decay mask for `params` (False for biases and layer-norm scales).\"\"\"\n",
"    layer_norm_scales = [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")]\n",
"    flat_mask = {}\n",
"    for path in traverse_util.flatten_dict(params):\n",
"        is_bias = path[-1] == \"bias\"\n",
"        flat_mask[path] = not (is_bias or path[-2:] in layer_norm_scales)\n",
"    return traverse_util.unflatten_dict(flat_mask)\n",
"\n",
"# AdamW driven by the schedule above, with weight decay masked by decay_mask_fn.\n",
"adamw = optax.adamw(\n",
"    learning_rate=linear_decay_lr_schedule_fn,\n",
"    weight_decay=training_args.weight_decay,\n",
"    mask=decay_mask_fn,\n",
"    b1=training_args.adam_beta1,\n",
"    b2=training_args.adam_beta2,\n",
"    eps=training_args.adam_epsilon,\n",
")\n",
"\n",
"# Train state: model parameters, optimizer and dropout key.\n",
"state = TrainState.create(\n",
"    apply_fn=model.__call__,\n",
"    params=model.params,\n",
"    tx=adamw,\n",
"    dropout_rng=dropout_rng,\n",
")\n",
"\n",
"def loss_fn(logits, labels):\n",
"    \"\"\"Mean softmax cross-entropy between each position's logits and the following position's label.\"\"\"\n",
"    num_classes = logits.shape[-1]\n",
"    # Drop the last prediction and the first label so that position i is scored against label i + 1.\n",
"    predictions = logits[..., :-1, :]\n",
"    targets = onehot(labels[..., 1:], num_classes)\n",
"    return optax.softmax_cross_entropy(predictions, targets).mean()\n",
"\n",
"# One optimization step; written to run under pmap with axis name \"batch\".\n",
"def train_step(state, batch):\n",
"    \"\"\"Apply one gradient update and return `(new_state, metrics)`.\n",
"\n",
"    Gradients and metrics are averaged over the \"batch\" axis. `metrics` holds the loss and the\n",
"    learning rate for the current step. The \"labels\" entry is removed from `batch`.\n",
"    \"\"\"\n",
"    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)\n",
"    labels = batch.pop(\"labels\")\n",
"\n",
"    def compute_loss(params):\n",
"        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n",
"        return loss_fn(logits, labels)\n",
"\n",
"    loss, grad = jax.value_and_grad(compute_loss)(state.params)\n",
"    grad = jax.lax.pmean(grad, \"batch\")\n",
"\n",
"    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)\n",
"\n",
"    metrics = jax.lax.pmean(\n",
"        {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)},\n",
"        axis_name=\"batch\",\n",
"    )\n",
"\n",
"    return new_state, metrics\n",
"\n",
"# One evaluation step; written to run under pmap with axis name \"batch\".\n",
"def eval_step(params, batch):\n",
"    \"\"\"Return `{\"loss\": ...}` for `batch`, averaged over the \"batch\" axis.\n",
"\n",
"    The \"labels\" entry is removed from `batch`; the model runs with `train=False`.\n",
"    \"\"\"\n",
"    targets = batch.pop(\"labels\")\n",
"    logits = model(**batch, params=params, train=False)[0]\n",
"\n",
"    # summarize metrics\n",
"    return jax.lax.pmean({\"loss\": loss_fn(logits, targets)}, axis_name=\"batch\")\n",
"\n",
"# Parallel versions of the two step functions; the train state argument is donated.\n",
"p_train_step = jax.pmap(train_step, axis_name=\"batch\", donate_argnums=(0,))\n",
"p_eval_step = jax.pmap(eval_step, axis_name=\"batch\")\n",
"\n",
"# Replicate the train state on each device\n",
"state = state.replicate()\n",
"\n",
"# Summary of the run that is about to start.\n",
"logger.info(\"***** Running training *****\")\n",
"logger.info(f\" Num examples = {len(train_dataset)}\")\n",
"logger.info(f\" Num Epochs = {num_epochs}\")\n",
"logger.info(f\" Instantaneous batch size per device = {training_args.per_device_train_batch_size}\")\n",
"logger.info(f\" Total train batch size (w. parallel & distributed) = {train_batch_size}\")\n",
"logger.info(f\" Total optimization steps = {total_train_steps}\")\n",
"\n",
"# Training loop: for every epoch, train over shuffled batches, evaluate, then checkpoint.\n",
"train_time = 0\n",
"train_metrics = []\n",
"epochs = tqdm(range(num_epochs), desc=f\"Epoch ... (1/{num_epochs})\", position=0)\n",
"for epoch in epochs:\n",
"    # ======================== Training ================================\n",
"    train_start = time.time()\n",
"\n",
"    # Create sampling rng\n",
"    rng, input_rng = jax.random.split(rng)\n",
"\n",
"    # Generate an epoch by shuffling sampling indices from the train dataset\n",
"    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)\n",
"    steps_per_epoch = len(train_dataset) // train_batch_size\n",
"    # train\n",
"    for step in tqdm(range(steps_per_epoch), desc=\"Training...\", position=1, leave=False):\n",
"        batch = next(train_loader)\n",
"        state, train_metric = p_train_step(state, batch)\n",
"        train_metrics.append(train_metric)\n",
"\n",
"        cur_step = epoch * steps_per_epoch + step\n",
"\n",
"        if cur_step % training_args.logging_steps == 0 and cur_step > 0:\n",
"            # Save metrics\n",
"            train_metric = unreplicate(train_metric)\n",
"            # Add only the time elapsed since the previous measurement; restarting the clock\n",
"            # here keeps the same interval from being counted again at the next log.\n",
"            now = time.time()\n",
"            train_time += now - train_start\n",
"            train_start = now\n",
"            if has_tensorboard and jax.process_index() == 0:\n",
"                write_train_metric(summary_writer, train_metrics, train_time, cur_step)\n",
"\n",
"            epochs.write(\n",
"                f\"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})\"\n",
"            )\n",
"\n",
"            train_metrics = []\n",
"\n",
"    # ======================== Evaluating ==============================\n",
"    eval_metrics = []\n",
"    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)\n",
"    eval_steps = len(eval_dataset) // eval_batch_size\n",
"    for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n",
"        # Model forward\n",
"        batch = next(eval_loader)\n",
"        metrics = p_eval_step(state.params, batch)\n",
"        eval_metrics.append(metrics)\n",
"\n",
"    # normalize eval metrics\n",
"    eval_metrics = get_metrics(eval_metrics)\n",
"\n",
"    # `jax.tree_util.tree_map` is the supported spelling of the deprecated `jax.tree_map` alias.\n",
"    eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)\n",
"\n",
"    try:\n",
"        eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n",
"    except OverflowError:\n",
"        eval_metrics[\"perplexity\"] = float(\"inf\")\n",
"\n",
"    # Print metrics and update progress bar\n",
"    desc = f\"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})\"\n",
"    epochs.write(desc)\n",
"    epochs.desc = desc\n",
"\n",
"    # Save metrics\n",
"    if has_tensorboard and jax.process_index() == 0:\n",
"        # Evaluation runs after the epoch's updates, so log it at the step count reached\n",
"        # rather than at the step where the epoch began.\n",
"        cur_step = (epoch + 1) * steps_per_epoch\n",
"        write_eval_metric(summary_writer, eval_metrics, cur_step)\n",
"\n",
"    # save checkpoint after each epoch and push checkpoint to the hub\n",
"    if jax.process_index() == 0:\n",
"        save_checkpoint(model, training_args.output_dir, state)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f39555f2-ccf7-458c-8756-be1ff330229b",
"metadata": {},
"outputs": [],
"source": [
"# Write a checkpoint of the current train state to the output directory.\n",
"save_checkpoint(model=model, save_dir=training_args.output_dir, state=state)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}