{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6d0d0fdf-def5-4c93-b1e9-4223d70c3c22",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#!/usr/bin/env python\n",
|
|
"# coding=utf-8\n",
|
|
"# Copyright 2021 The HuggingFace Team All rights reserved.\n",
|
|
"#\n",
|
|
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
|
"# you may not use this file except in compliance with the License.\n",
|
|
"# You may obtain a copy of the License at\n",
|
|
"#\n",
|
|
"# http://www.apache.org/licenses/LICENSE-2.0\n",
|
|
"#\n",
|
|
"# Unless required by applicable law or agreed to in writing, software\n",
|
|
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
|
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
|
"# See the License for the specific language governing permissions and\n",
|
|
"# limitations under the License.\n",
|
|
"\"\"\"\n",
|
|
"Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.\n",
|
|
"\n",
|
|
"Here is the full list of checkpoints on the hub that can be fine-tuned by this script:\n",
|
|
"https://huggingface.co/models?filter=causal-lm\n",
|
|
"\"\"\"\n",
|
|
"# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.\n",
|
|
"\n",
|
|
"import logging\n",
|
|
"import math\n",
|
|
"import os\n",
|
|
"import sys\n",
|
|
"import time\n",
|
|
"from dataclasses import dataclass, field\n",
|
|
"from pathlib import Path\n",
|
|
"from typing import Callable, Optional\n",
|
|
"\n",
|
|
"import datasets\n",
|
|
"from datasets import Dataset, load_dataset\n",
|
|
"from tqdm.auto import tqdm\n",
|
|
"\n",
|
|
"import json\n",
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import optax\n",
|
|
"import transformers\n",
|
|
"from flax import jax_utils, traverse_util\n",
|
|
"from flax.jax_utils import unreplicate\n",
|
|
"from flax.training import train_state\n",
|
|
"from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key\n",
|
|
"from flax.serialization import to_bytes, from_bytes\n",
|
|
"from transformers import (\n",
|
|
" CONFIG_MAPPING,\n",
|
|
" FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n",
|
|
" AutoConfig,\n",
|
|
" AutoTokenizer,\n",
|
|
" FlaxAutoModelForCausalLM,\n",
|
|
" HfArgumentParser,\n",
|
|
" TrainingArguments,\n",
|
|
" is_tensorboard_available,\n",
|
|
")\n",
|
|
"from transformers.testing_utils import CaptureLogger\n",
|
|
"\n",
|
|
"\n",
|
|
"logger = logging.getLogger(__name__)\n",
|
|
"\n",
|
|
"MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())\n",
|
|
"MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n",
|
|
"\n",
|
|
"\n",
|
|
"@dataclass\n",
|
|
"class ModelArguments:\n",
|
|
" \"\"\"\n",
|
|
" Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" model_name_or_path: Optional[str] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\n",
|
|
" \"help\": \"The model checkpoint for weights initialization.\"\n",
|
|
" \"Don't set if you want to train a model from scratch.\"\n",
|
|
" },\n",
|
|
" )\n",
|
|
" model_type: Optional[str] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n",
|
|
" )\n",
|
|
" config_name: Optional[str] = field(\n",
|
|
" default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n",
|
|
" )\n",
|
|
" tokenizer_name: Optional[str] = field(\n",
|
|
" default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n",
|
|
" )\n",
|
|
" cache_dir: Optional[str] = field(\n",
|
|
" default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n",
|
|
" )\n",
|
|
" use_fast_tokenizer: bool = field(\n",
|
|
" default=True,\n",
|
|
" metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n",
|
|
" )\n",
|
|
" dtype: Optional[str] = field(\n",
|
|
" default=\"float32\",\n",
|
|
" metadata={\n",
|
|
" \"help\": \"Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`.\"\n",
|
|
" },\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"@dataclass\n",
|
|
"class DataTrainingArguments:\n",
|
|
" \"\"\"\n",
|
|
" Arguments pertaining to what data we are going to input our model for training and eval.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" dataset_name: Optional[str] = field(\n",
|
|
" default=None, metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n",
|
|
" )\n",
|
|
" dataset_config_name: Optional[str] = field(\n",
|
|
" default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n",
|
|
" )\n",
|
|
" train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n",
|
|
" validation_file: Optional[str] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n",
|
|
" )\n",
|
|
" max_train_samples: Optional[int] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\n",
|
|
" \"help\": \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n",
|
|
" \"value if set.\"\n",
|
|
" },\n",
|
|
" )\n",
|
|
" max_eval_samples: Optional[int] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\n",
|
|
" \"help\": \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n",
|
|
" \"value if set.\"\n",
|
|
" },\n",
|
|
" )\n",
|
|
" overwrite_cache: bool = field(\n",
|
|
" default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n",
|
|
" )\n",
|
|
" validation_split_percentage: Optional[int] = field(\n",
|
|
" default=5,\n",
|
|
" metadata={\n",
|
|
" \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n",
|
|
" },\n",
|
|
" )\n",
|
|
" block_size: Optional[int] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\n",
|
|
" \"help\": \"Optional input sequence length after tokenization. \"\n",
|
|
" \"The training dataset will be truncated in block of this size for training. \"\n",
|
|
" \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n",
|
|
" },\n",
|
|
" )\n",
" preprocessing_num_workers: Optional[int] = field(\n",
|
|
" default=None,\n",
|
|
" metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n",
|
|
" )\n",
|
|
" text_column_name: Optional[str] = field(\n",
|
|
" default='text',\n",
|
|
" metadata={\"help\": \"Column containing main text data.\"},\n",
|
|
" )\n",
|
|
"\n",
|
|
" def __post_init__(self):\n",
|
|
" if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n",
|
|
" raise ValueError(\"Need either a dataset name or a training/validation file.\")\n",
|
|
" else:\n",
|
|
" if self.train_file is not None:\n",
|
|
" extension = self.train_file.split(\".\")[-1]\n",
|
|
" assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n",
|
|
" if self.validation_file is not None:\n",
|
|
" extension = self.validation_file.split(\".\")[-1]\n",
|
|
" assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n",
|
|
"\n",
|
|
"\n",
|
|
"class TrainState(train_state.TrainState):\n",
|
|
" dropout_rng: jnp.ndarray\n",
|
|
"\n",
|
|
" def replicate(self):\n",
|
|
" return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))\n",
|
|
"\n",
|
|
"\n",
|
|
"def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
|
|
" \"\"\"\n",
|
|
" Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
|
|
" Shuffle batches if `shuffle` is `True`.\n",
|
|
" \"\"\"\n",
|
|
" steps_per_epoch = len(dataset) // batch_size\n",
|
|
"\n",
|
|
" if shuffle:\n",
|
|
" batch_idx = jax.random.permutation(rng, len(dataset))\n",
|
|
" else:\n",
|
|
" batch_idx = jnp.arange(len(dataset))\n",
|
|
"\n",
|
|
" batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
|
|
" batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
|
|
"\n",
|
|
" for idx in batch_idx:\n",
|
|
" batch = dataset[idx]\n",
|
|
" batch = {k: jnp.array(v) for k, v in batch.items()}\n",
|
|
"\n",
|
|
" batch = shard(batch)\n",
|
|
"\n",
|
|
" yield batch\n",
|
|
"\n",
|
|
"\n",
|
|
"def write_train_metric(summary_writer, train_metrics, train_time, step):\n",
|
|
" summary_writer.scalar(\"train_time\", train_time, step)\n",
|
|
"\n",
|
|
" train_metrics = get_metrics(train_metrics)\n",
|
|
" for key, vals in train_metrics.items():\n",
|
|
" tag = f\"train_{key}\"\n",
|
|
" for i, val in enumerate(vals):\n",
|
|
" summary_writer.scalar(tag, val, step - len(vals) + i + 1)\n",
|
|
"\n",
|
|
"\n",
|
|
"def write_eval_metric(summary_writer, eval_metrics, step):\n",
|
|
" for metric_name, value in eval_metrics.items():\n",
|
|
" summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n",
|
|
"\n",
|
|
"\n",
|
|
"def create_learning_rate_fn(\n",
|
|
" train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n",
|
|
") -> Callable[[int], jnp.array]:\n",
|
|
" \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n",
|
|
" steps_per_epoch = train_ds_size // train_batch_size\n",
|
|
" num_train_steps = steps_per_epoch * num_train_epochs\n",
|
|
" warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n",
|
|
" decay_fn = optax.linear_schedule(\n",
|
|
" init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n",
|
|
" )\n",
|
|
" schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n",
|
|
" return schedule_fn"
|
|
]
|
|
},
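{
"cell_type": "code",
"execution_count": null,
"id": "lr-schedule-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (not part of the original training flow): probe the warmup +\n",
"# linear-decay schedule returned by `create_learning_rate_fn` above. The dataset\n",
"# size, batch size, epoch count and warmup steps here are illustrative\n",
"# assumptions, not values used elsewhere in this notebook.\n",
"_sketch_fn = create_learning_rate_fn(10_000, 64, 2, 50, 3e-4)\n",
"print([float(_sketch_fn(step)) for step in (0, 25, 50, 150, 311)])"
]
},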
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b87fb52e-7e4b-4a69-8c63-fe739331c1c5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model_args = ModelArguments(\n",
|
|
" model_name_or_path=\"EleutherAI/gpt-neo-125M\",\n",
|
|
" dtype=\"bfloat16\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "75c30ff3-94ec-437f-af79-17ad8429d7eb",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"data_args = DataTrainingArguments(\n",
|
|
" dataset_name=\"code_search_net\", \n",
|
|
" dataset_config_name=\"python\", \n",
|
|
" block_size=1024,\n",
|
|
" max_train_samples=10000, \n",
|
|
" max_eval_samples=1000, \n",
|
|
" preprocessing_num_workers=8\n",
|
|
")"
|
|
]
|
|
},
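{
"cell_type": "code",
"execution_count": null,
"id": "dataset-peek-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check, a sketch rather than part of the training flow: load a\n",
"# tiny slice of the configured dataset and list its columns, so it is easy to see\n",
"# which column the `text_column_name` fallback in the main cell will pick. The 1%\n",
"# slice is an arbitrary illustrative choice; the full dataset is loaded later.\n",
"_peek = load_dataset(data_args.dataset_name, data_args.dataset_config_name, split=\"train[:1%]\")\n",
"print(_peek)\n",
"print(_peek.column_names)"
]
},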
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e618fc54-07d1-4e6c-bb48-7e050044a0c8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"bs = 16\n",
|
|
"training_args = TrainingArguments(\n",
|
|
" num_train_epochs=1,\n",
|
|
" output_dir=\"./tmp\", \n",
|
|
" per_device_train_batch_size=bs, \n",
|
|
" per_device_eval_batch_size=bs, \n",
|
|
" learning_rate=3e-4,\n",
|
|
" weight_decay=0.1,\n",
|
|
" do_train=True,\n",
|
|
" do_eval=True,\n",
|
|
" warmup_steps=100,\n",
|
|
" push_to_hub=False,\n",
|
|
" overwrite_output_dir=True,\n",
|
|
" report_to=None\n",
|
|
")"
|
|
]
|
|
},
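{
"cell_type": "code",
"execution_count": null,
"id": "global-batch-size-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A small sketch: the training loop below multiplies the per-device batch size by\n",
"# the number of visible JAX devices, so the effective (global) batch size depends\n",
"# on the machine this runs on (e.g. 8 TPU cores => 8 * 16 = 128). Printing the\n",
"# numbers here makes that explicit before training starts.\n",
"print(\"devices:\", jax.device_count())\n",
"print(\"global train batch size:\", training_args.per_device_train_batch_size * jax.device_count())\n",
"print(\"global eval batch size:\", training_args.per_device_eval_batch_size * jax.device_count())"
]
},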
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "20645921-b444-4e8d-b582-08af685b42d0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def save_checkpoint(model, save_dir, state):\n",
|
|
" state = jax_utils.unreplicate(state)\n",
|
|
" print(f\"SAVING CHECKPOINT IN {save_dir}\", end=\" ... \")\n",
|
|
" model.save_pretrained(\n",
|
|
" training_args.output_dir,\n",
|
|
" params=state.params,\n",
|
|
" push_to_hub=training_args.push_to_hub,\n",
|
|
" commit_message=f\"Saving weights and logs of epoch {epoch+1}\",\n",
|
|
" )\n",
|
|
" with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"wb\") as f:\n",
|
|
" f.write(to_bytes(state.opt_state))\n",
|
|
" with open(os.path.join(save_dir, \"training_state.json\"), \"w\") as f:\n",
|
|
" json.dump({\"step\": state.step.item()}, f)\n",
|
|
" print(\"DONE\")\n",
|
|
" \n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5d11fd82-435d-4afd-b1eb-9914cd465a18",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def restore_checkpoint(save_dir, state):\n",
|
|
" print(f\"RESTORING CHECKPOINT FROM {save_dir}\", end=\" ... \")\n",
|
|
" with open(os.path.join(save_dir, \"flax_model.msgpack\"), \"rb\") as f:\n",
|
|
" params = from_bytes(state.params, f.read())\n",
|
|
"\n",
|
|
" with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"rb\") as f:\n",
|
|
" opt_state = from_bytes(state.opt_state, f.read())\n",
|
|
"\n",
|
|
" with open(os.path.join(save_dir, \"training_state.json\"), \"r\") as f:\n",
|
|
" training_state = json.load(f)\n",
|
|
" step = training_state[\"step\"]\n",
|
|
"\n",
|
|
" print(\"DONE\")\n",
|
|
" return params, opt_state, step"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "562e0fd9-83b5-4d51-9232-df97fca4f063",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"INFO:absl:A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
|
|
" 10000,\n",
|
|
" 128,\n",
|
|
" training_args.num_train_epochs,\n",
|
|
" training_args.warmup_steps,\n",
|
|
" training_args.learning_rate,\n",
|
|
" )\n",
|
|
"\n",
"def decay_mask_fn(params):\n",
|
|
" flat_params = traverse_util.flatten_dict(params)\n",
|
|
" flat_mask = {\n",
|
|
" path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n",
|
|
" for path in flat_params\n",
|
|
" }\n",
|
|
" return traverse_util.unflatten_dict(flat_mask)\n",
|
|
"\n",
|
|
"optimizer = optax.adamw(\n",
|
|
" learning_rate=linear_decay_lr_schedule_fn,\n",
|
|
" b1=training_args.adam_beta1,\n",
|
|
" b2=training_args.adam_beta2,\n",
|
|
" eps=training_args.adam_epsilon,\n",
|
|
" weight_decay=training_args.weight_decay,\n",
|
|
" mask=decay_mask_fn,\n",
|
|
" )\n",
|
|
"model = FlaxAutoModelForCausalLM.from_pretrained(training_args.output_dir)\n",
|
|
"rng = jax.random.PRNGKey(training_args.seed)\n",
|
|
"rng, dropout_rng = jax.random.split(rng)\n",
|
|
"state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "56d2d8be-3516-4f86-b0e5-20b3a9079163",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"RESTORING CHECKPOINT FROM ./tmp ... DONE\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"params, opt_state, step = restore_checkpoint(training_args.output_dir, state)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0db5eca6-f081-4c3e-a7f1-3e080da8cc89",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"if (\n",
|
|
" os.path.exists(training_args.output_dir)\n",
|
|
" and os.listdir(training_args.output_dir)\n",
|
|
" and training_args.do_train\n",
|
|
" and not training_args.overwrite_output_dir\n",
|
|
"):\n",
|
|
" raise ValueError(\n",
|
|
" f\"Output directory ({training_args.output_dir}) already exists and is not empty.\"\n",
|
|
" \"Use --overwrite_output_dir to overcome.\"\n",
|
|
" )\n",
|
|
"\n",
|
|
"# Make one log on every process with the configuration for debugging.\n",
|
|
"logging.basicConfig(\n",
|
|
" format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
|
|
" datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
|
|
" level=logging.INFO,\n",
|
|
")\n",
|
|
"# Setup logging, we only want one process per machine to log things on the screen.\n",
|
|
"logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n",
|
|
"if jax.process_index() == 0:\n",
|
|
" datasets.utils.logging.set_verbosity_warning()\n",
|
|
" transformers.utils.logging.set_verbosity_info()\n",
|
|
"else:\n",
|
|
" datasets.utils.logging.set_verbosity_error()\n",
|
|
" transformers.utils.logging.set_verbosity_error()\n",
|
|
"\n",
|
|
"# Set the verbosity to info of the Transformers logger (on main process only):\n",
|
|
"logger.info(f\"Training/evaluation parameters {training_args}\")\n",
|
|
"\n",
|
|
"# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n",
|
|
"# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n",
|
|
"# (the dataset will be downloaded automatically from the datasets Hub).\n",
|
|
"#\n",
|
|
"# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n",
|
|
"# 'text' is found. You can easily tweak this behavior (see below).\n",
|
|
"#\n",
|
|
"# In distributed training, the load_dataset function guarantees that only one local process can concurrently\n",
|
|
"# download the dataset.\n",
|
|
"if data_args.dataset_name is not None:\n",
|
|
" # Downloading and loading a dataset from the hub.\n",
|
|
" dataset = load_dataset(\n",
|
|
" data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False\n",
|
|
" )\n",
|
|
"\n",
|
|
" if \"validation\" not in dataset.keys():\n",
|
|
" dataset[\"validation\"] = load_dataset(\n",
|
|
" data_args.dataset_name,\n",
|
|
" data_args.dataset_config_name,\n",
|
|
" split=f\"train[:{data_args.validation_split_percentage}%]\",\n",
|
|
" cache_dir=model_args.cache_dir,\n",
|
|
" )\n",
|
|
" dataset[\"train\"] = load_dataset(\n",
|
|
" data_args.dataset_name,\n",
|
|
" data_args.dataset_config_name,\n",
|
|
" split=f\"train[{data_args.validation_split_percentage}%:]\",\n",
|
|
" cache_dir=model_args.cache_dir,\n",
|
|
" )\n",
|
|
"else:\n",
|
|
" data_files = {}\n",
|
|
" if data_args.train_file is not None:\n",
|
|
" data_files[\"train\"] = data_args.train_file\n",
|
|
" if data_args.validation_file is not None:\n",
|
|
" data_files[\"validation\"] = data_args.validation_file\n",
|
|
" extension = data_args.train_file.split(\".\")[-1]\n",
|
|
" if extension == \"txt\":\n",
|
|
" extension = \"text\"\n",
|
|
" dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)\n",
|
|
"# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n",
|
|
"# https://huggingface.co/docs/datasets/loading_datasets.html.\n",
|
|
"\n",
|
|
"# Load pretrained model and tokenizer\n",
|
|
"\n",
|
|
"# Distributed training:\n",
|
|
"# The .from_pretrained methods guarantee that only one local process can concurrently\n",
|
|
"# download model & vocab.\n",
|
|
"if model_args.config_name:\n",
|
|
" config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)\n",
|
|
"elif model_args.model_name_or_path:\n",
|
|
" config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)\n",
|
|
"else:\n",
|
|
" config = CONFIG_MAPPING[model_args.model_type]()\n",
|
|
" logger.warning(\"You are instantiating a new config instance from scratch.\")\n",
|
|
"\n",
|
|
"if model_args.tokenizer_name:\n",
|
|
" tokenizer = AutoTokenizer.from_pretrained(\n",
|
|
" model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer\n",
|
|
" )\n",
|
|
"elif model_args.model_name_or_path:\n",
|
|
" tokenizer = AutoTokenizer.from_pretrained(\n",
|
|
" model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer\n",
|
|
" )\n",
|
|
"else:\n",
|
|
" raise ValueError(\n",
|
|
" \"You are instantiating a new tokenizer from scratch. This is not supported by this script.\"\n",
|
|
" \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n",
|
|
" )\n",
|
|
"\n",
|
|
"if model_args.model_name_or_path:\n",
|
|
" model = FlaxAutoModelForCausalLM.from_pretrained(\n",
|
|
" model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)\n",
|
|
" )\n",
|
|
"else:\n",
|
|
" model = FlaxAutoModelForCausalLM.from_config(\n",
|
|
" config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)\n",
|
|
" )\n",
|
|
"\n",
|
|
"# Preprocessing the datasets.\n",
|
|
"# First we tokenize all the texts.\n",
|
|
"if training_args.do_train:\n",
|
|
" column_names = dataset[\"train\"].column_names\n",
|
|
"else:\n",
|
|
" column_names = dataset[\"validation\"].column_names\n",
|
|
"text_column_name = data_args.text_column_name if data_args.text_column_name in column_names else column_names[0]\n",
|
|
"\n",
|
|
"# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function\n",
|
|
"tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n",
|
|
"\n",
|
|
"def tokenize_function(examples):\n",
|
|
" with CaptureLogger(tok_logger) as cl:\n",
|
|
" output = tokenizer(examples[text_column_name])\n",
|
|
" # clm input could be much much longer than block_size\n",
|
|
" if \"Token indices sequence length is longer than the\" in cl.out:\n",
|
|
" tok_logger.warning(\n",
|
|
" \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model.\"\n",
|
|
" )\n",
|
|
" return output\n",
|
|
"\n",
|
|
"tokenized_datasets = dataset.map(\n",
|
|
" tokenize_function,\n",
|
|
" batched=True,\n",
|
|
" num_proc=data_args.preprocessing_num_workers,\n",
|
|
" remove_columns=column_names,\n",
|
|
" load_from_cache_file=not data_args.overwrite_cache,\n",
|
|
")\n",
|
|
"\n",
|
|
"if data_args.block_size is None:\n",
|
|
" block_size = tokenizer.model_max_length\n",
|
|
" if block_size > config.max_position_embeddings:\n",
|
|
" logger.warning(\n",
|
|
" f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n",
|
|
" \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n",
|
|
" )\n",
|
|
" block_size = 1024\n",
|
|
"else:\n",
|
|
" if data_args.block_size > tokenizer.model_max_length:\n",
|
|
" logger.warning(\n",
|
|
" f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model\"\n",
|
|
" f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n",
|
|
" )\n",
|
|
" block_size = min(data_args.block_size, tokenizer.model_max_length)\n",
|
|
"\n",
|
|
"# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n",
|
|
"def group_texts(examples):\n",
|
|
" # Concatenate all texts.\n",
|
|
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
|
|
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
|
|
" # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
|
|
" # customize this part to your needs.\n",
|
|
" total_length = (total_length // block_size) * block_size\n",
|
|
" # Split by chunks of max_len.\n",
|
|
" result = {\n",
|
|
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
|
|
" for k, t in concatenated_examples.items()\n",
|
|
" }\n",
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
|
|
" return result\n",
|
|
"\n",
|
|
"# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n",
|
|
"# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n",
|
|
"# to preprocess.\n",
|
|
"#\n",
|
|
"# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:\n",
|
|
"# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n",
|
|
"\n",
|
|
"lm_datasets = tokenized_datasets.map(\n",
|
|
" group_texts,\n",
|
|
" batched=True,\n",
|
|
" num_proc=data_args.preprocessing_num_workers,\n",
|
|
" load_from_cache_file=not data_args.overwrite_cache,\n",
|
|
")\n",
|
|
"\n",
|
|
"if training_args.do_train:\n",
|
|
" if \"train\" not in tokenized_datasets:\n",
|
|
" raise ValueError(\"--do_train requires a train dataset\")\n",
|
|
" train_dataset = lm_datasets[\"train\"]\n",
|
|
" if data_args.max_train_samples is not None:\n",
|
|
" train_dataset = train_dataset.select(range(data_args.max_train_samples))\n",
|
|
"\n",
|
|
"if training_args.do_eval:\n",
|
|
" if \"validation\" not in tokenized_datasets:\n",
|
|
" raise ValueError(\"--do_eval requires a validation dataset\")\n",
|
|
" eval_dataset = lm_datasets[\"validation\"]\n",
|
|
" if data_args.max_eval_samples is not None:\n",
|
|
" eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))\n",
|
|
"\n",
|
|
"# Enable tensorboard only on the master node\n",
|
|
"has_tensorboard = is_tensorboard_available()\n",
|
|
"if has_tensorboard and jax.process_index() == 0:\n",
|
|
" try:\n",
|
|
" from flax.metrics.tensorboard import SummaryWriter\n",
|
|
"\n",
|
|
" summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n",
|
|
" except ImportError as ie:\n",
|
|
" has_tensorboard = False\n",
|
|
" logger.warning(\n",
|
|
" f\"Unable to display metrics through TensorBoard because some package are not installed: {ie}\"\n",
|
|
" )\n",
|
|
"else:\n",
|
|
" logger.warning(\n",
|
|
" \"Unable to display metrics through TensorBoard because the package is not installed: \"\n",
|
|
" \"Please run pip install tensorboard to enable.\"\n",
|
|
" )\n",
|
|
"\n",
|
|
"# Initialize our training\n",
|
|
"rng = jax.random.PRNGKey(training_args.seed)\n",
|
|
"rng, dropout_rng = jax.random.split(rng)\n",
|
|
"\n",
|
|
"# Store some constant\n",
|
|
"num_epochs = int(training_args.num_train_epochs)\n",
|
|
"train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()\n",
|
|
"eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()\n",
|
|
"steps_per_epoch = len(train_dataset) // train_batch_size\n",
|
|
"total_train_steps = steps_per_epoch * num_epochs\n",
|
|
"\n",
|
|
"# Create learning rate schedule\n",
|
|
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
|
|
" len(train_dataset),\n",
|
|
" train_batch_size,\n",
|
|
" training_args.num_train_epochs,\n",
|
|
" training_args.warmup_steps,\n",
|
|
" training_args.learning_rate,\n",
|
|
")\n",
|
|
"\n",
|
|
"# We use Optax's \"masking\" functionality to not apply weight decay\n",
|
|
"# to bias and LayerNorm scale parameters. decay_mask_fn returns a\n",
|
|
"# mask boolean with the same structure as the parameters.\n",
|
|
"# The mask is True for parameters that should be decayed.\n",
|
|
"# Note that this mask is specifically adapted for FlaxGPT2.\n",
|
|
"# For other models, one should correct the layer norm parameter naming\n",
|
|
"# accordingly.\n",
|
|
"def decay_mask_fn(params):\n",
|
|
" flat_params = traverse_util.flatten_dict(params)\n",
|
|
" flat_mask = {\n",
|
|
" path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n",
|
|
" for path in flat_params\n",
|
|
" }\n",
|
|
" return traverse_util.unflatten_dict(flat_mask)\n",
|
|
"\n",
|
|
"# create adam optimizer\n",
|
|
"adamw = optax.adamw(\n",
|
|
" learning_rate=linear_decay_lr_schedule_fn,\n",
|
|
" b1=training_args.adam_beta1,\n",
|
|
" b2=training_args.adam_beta2,\n",
|
|
" eps=training_args.adam_epsilon,\n",
|
|
" weight_decay=training_args.weight_decay,\n",
|
|
" mask=decay_mask_fn,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Setup train state\n",
|
|
"state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)\n",
|
|
"\n",
|
|
"def loss_fn(logits, labels):\n",
" shift_logits = logits[..., :-1, :]\n",
|
|
" shift_labels = labels[..., 1:]\n",
|
|
" loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))\n",
|
|
" return loss.mean()\n",
|
|
"\n",
|
|
"# Define gradient update step fn\n",
|
|
"def train_step(state, batch):\n",
|
|
" dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)\n",
|
|
"\n",
|
|
" def compute_loss(params):\n",
|
|
" labels = batch.pop(\"labels\")\n",
|
|
" logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n",
|
|
" loss = loss_fn(logits, labels)\n",
|
|
" return loss\n",
|
|
"\n",
|
|
" grad_fn = jax.value_and_grad(compute_loss)\n",
|
|
" loss, grad = grad_fn(state.params)\n",
|
|
" grad = jax.lax.pmean(grad, \"batch\")\n",
|
|
"\n",
|
|
" new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)\n",
|
|
"\n",
|
|
" metrics = {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}\n",
|
|
" metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n",
|
|
"\n",
|
|
" return new_state, metrics\n",
|
|
"\n",
|
|
"# Define eval fn\n",
|
|
"def eval_step(params, batch):\n",
|
|
" labels = batch.pop(\"labels\")\n",
|
|
" logits = model(**batch, params=params, train=False)[0]\n",
|
|
" loss = loss_fn(logits, labels)\n",
|
|
"\n",
|
|
" # summarize metrics\n",
|
|
" metrics = {\"loss\": loss}\n",
|
|
" metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n",
|
|
" return metrics\n",
|
|
"\n",
|
|
"# Create parallel version of the train and eval step\n",
|
|
"p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n",
|
|
"p_eval_step = jax.pmap(eval_step, \"batch\")\n",
|
|
"\n",
|
|
"# Replicate the train state on each device\n",
|
|
"state = state.replicate()\n",
|
|
"\n",
|
|
"logger.info(\"***** Running training *****\")\n",
|
|
"logger.info(f\" Num examples = {len(train_dataset)}\")\n",
|
|
"logger.info(f\" Num Epochs = {num_epochs}\")\n",
|
|
"logger.info(f\" Instantaneous batch size per device = {training_args.per_device_train_batch_size}\")\n",
|
|
"logger.info(f\" Total train batch size (w. parallel & distributed) = {train_batch_size}\")\n",
|
|
"logger.info(f\" Total optimization steps = {total_train_steps}\")\n",
|
|
"\n",
|
|
"train_time = 0\n",
|
|
"train_metrics = []\n",
|
|
"epochs = tqdm(range(num_epochs), desc=f\"Epoch ... (1/{num_epochs})\", position=0)\n",
|
|
"for epoch in epochs:\n",
|
|
" # ======================== Training ================================\n",
|
|
" train_start = time.time()\n",
|
|
"\n",
|
|
" # Create sampling rng\n",
|
|
" rng, input_rng = jax.random.split(rng)\n",
|
|
"\n",
|
|
" # Generate an epoch by shuffling sampling indices from the train dataset\n",
|
|
" train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)\n",
|
|
" steps_per_epoch = len(train_dataset) // train_batch_size\n",
|
|
" # train\n",
|
|
" for step in tqdm(range(steps_per_epoch), desc=\"Training...\", position=1, leave=False):\n",
|
|
" batch = next(train_loader)\n",
|
|
" state, train_metric = p_train_step(state, batch)\n",
|
|
" train_metrics.append(train_metric)\n",
|
|
"\n",
|
|
" cur_step = epoch * (len(train_dataset) // train_batch_size) + step\n",
|
|
"\n",
|
|
" if cur_step % training_args.logging_steps == 0 and cur_step > 0:\n",
|
|
" # Save metrics\n",
|
|
" train_metric = unreplicate(train_metric)\n",
|
|
" train_time += time.time() - train_start\n",
|
|
" if has_tensorboard and jax.process_index() == 0:\n",
|
|
" write_train_metric(summary_writer, train_metrics, train_time, cur_step)\n",
|
|
"\n",
|
|
" epochs.write(\n",
|
|
" f\"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" train_metrics = []\n",
|
|
"\n",
|
|
" # ======================== Evaluating ==============================\n",
|
|
" eval_metrics = []\n",
|
|
" eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)\n",
|
|
" eval_steps = len(eval_dataset) // eval_batch_size\n",
|
|
" for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n",
|
|
" # Model forward\n",
|
|
" batch = next(eval_loader)\n",
|
|
" metrics = p_eval_step(state.params, batch)\n",
|
|
" eval_metrics.append(metrics)\n",
|
|
"\n",
|
|
" # normalize eval metrics\n",
|
|
" eval_metrics = get_metrics(eval_metrics)\n",
|
|
"\n",
|
|
" eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n",
|
|
"\n",
|
|
" try:\n",
|
|
" eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n",
|
|
" except OverflowError:\n",
|
|
" eval_metrics[\"perplexity\"] = float(\"inf\")\n",
|
|
"\n",
|
|
" # Print metrics and update progress bar\n",
|
|
" desc = f\"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})\"\n",
|
|
" epochs.write(desc)\n",
|
|
" epochs.desc = desc\n",
|
|
"\n",
|
|
" # Save metrics\n",
|
|
" if has_tensorboard and jax.process_index() == 0:\n",
|
|
" cur_step = epoch * (len(train_dataset) // train_batch_size)\n",
|
|
" write_eval_metric(summary_writer, eval_metrics, cur_step)\n",
|
|
"\n",
|
|
" # save checkpoint after each epoch and push checkpoint to the hub\n",
|
|
" if jax.process_index() == 0:\n",
|
|
" save_checkpoint(model, training_args.output_dir, state)\n",
|
|
" \n",
|
|
"\n",
|
|
"\n",
|
|
"# if __name__ == \"__main__\":\n",
|
|
"# main()"
|
|
]
|
|
},
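{
"cell_type": "code",
"execution_count": null,
"id": "loss-shift-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A tiny, self-contained sketch of the next-token shift inside `loss_fn` from the\n",
"# cell above: the logit at position i is scored against the label at position i + 1.\n",
"# The 5-token vocabulary and the hand-made sequences are illustrative assumptions.\n",
"_toy_labels = jnp.array([[1, 2, 3, 0]])  # (batch=1, seq_len=4)\n",
"# Logits that merely repeat the current token give a large loss ...\n",
"_copy_logits = jax.nn.one_hot(_toy_labels, 5) * 10.0\n",
"print(\"copy-current-token loss:\", float(loss_fn(_copy_logits, _toy_labels)))\n",
"# ... while logits that point at the *next* token give a near-zero loss.\n",
"_next_logits = jax.nn.one_hot(jnp.array([[2, 3, 0, 0]]), 5) * 10.0\n",
"print(\"predict-next-token loss:\", float(loss_fn(_next_logits, _toy_labels)))"
]
},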
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f39555f2-ccf7-458c-8756-be1ff330229b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"save_checkpoint(model, training_args.output_dir, state)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|