diff --git a/code-clippy-app.ipynb b/code-clippy-app.ipynb new file mode 100644 index 0000000..0194c78 --- /dev/null +++ b/code-clippy-app.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "333a4f54-e120-4969-8adf-32b98655ff41", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-07-18 23:45:31.083087: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2021-07-18 23:45:31.083131: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + ] + } + ], + "source": [ + "import gradio as gr\n", + "\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d875c6bc-8d97-4e03-9ee4-a4bd47f191bf", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/arto/transformers/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n", + " pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)\n", + "All Flax model weights were used when initializing GPTNeoForCausalLM.\n", + "\n", + "Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['transformer.h.4.attn.attention.bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.10.attn.attention.bias', 'lm_head.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model_name = \"/home/shared/models/gpt-code-clippy-125M-apps-lr-adam1e-4-bs128/ckpt-1633/\"\n", + "model = AutoModelForCausalLM.from_pretrained(model_name, from_flax=True)\n", + "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125M\")\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e6ebb017-e784-4311-a795-4bacd2263d19", + "metadata": {}, + "outputs": [], + "source": [ + "def format_input(question, starter_code=\"\"):\n", + " answer_type = \"\\nUse Call-Based format\\n\" if starter_code else \"\\nUse Standard Input format\\n\"\n", + " return 
f\"\\nQUESTION:\\n{question}\\n{starter_code}\\n{answer_type}\\nANSWER:\\n\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b00ad3ab-3086-48f0-940a-3129a7dff30a", + "metadata": {}, + "outputs": [], + "source": [ + "def format_outputs(text):\n", + " formatted_text =f'''\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
{text}
\n", + " \n", + " '''\n", + " return formatted_text" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "28abdab5-962e-47af-83b2-4f9c491ba705", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_solution(question, starter_code=\"\", temperature=1., num_beams=1):\n", + " prompt = format_input(question, starter_code)\n", + " input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + " start = len(input_ids[0])\n", + " output = model.generate(\n", + " input_ids,\n", + " max_length=start+200,\n", + " do_sample=True,\n", + " top_p=0.95,\n", + " pad_token_id=tokenizer.pad_token_id,\n", + " early_stopping=True,\n", + " temperature=1.,\n", + " num_beams=int(num_beams),\n", + " no_repeat_ngram_size=None,\n", + " repetition_penalty=None,\n", + " num_return_sequences=None,\n", + " )\n", + " \n", + " return format_outputs(tokenizer.decode(output[0][start:]).strip())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a156ff76-7e50-40c9-8b3e-e9ddeb450501", + "metadata": {}, + "outputs": [], + "source": [ + "_EXAMPLES = [\n", + " [\n", + " \"\"\"\n", + "Given a 2D list of size `m * n`. Your task is to find the sum of minimum value in each row.\n", + "For Example:\n", + "```python\n", + "[\n", + " [1, 2, 3, 4, 5], # minimum value of row is 1\n", + " [5, 6, 7, 8, 9], # minimum value of row is 5\n", + " [20, 21, 34, 56, 100] # minimum value of row is 20\n", + "]\n", + "```\n", + "So, the function should return `26` because sum of minimums is as `1 + 5 + 20 = 26`\n", + " \"\"\",\n", + " \"\",\n", + " 0.8,\n", + " ],\n", + " [\n", + " \"\"\"\n", + "# Personalized greeting\n", + "\n", + "Create a function that gives a personalized greeting. This function takes two parameters: `name` and `owner`.\n", + " \"\"\",\n", + " \"\"\"\n", + "Use conditionals to return the proper message:\n", + "\n", + "case| return\n", + "--- | ---\n", + "name equals owner | 'Hello boss'\n", + "otherwise | 'Hello guest'\n", + "def greet(name, owner):\n", + " \"\"\",\n", + " 0.8,\n", + " ]\n", + "] " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "fb3d90fc-6932-4343-9c86-70ae94ca95aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running locally at: http://127.0.0.1:7861/\n", + "This share link will expire in 24 hours. 
If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)\n", + "Running on External URL: https://34711.gradio.app\n", + "Interface loading below...\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(,\n", + " 'http://127.0.0.1:7861/',\n", + " 'https://34711.gradio.app')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs = [\n", + " gr.inputs.Textbox(placeholder=\"Define a problem here...\", lines=7),\n", + " gr.inputs.Textbox(placeholder=\"Provide optional starter code...\", lines=3),\n", + " gr.inputs.Slider(0.5, 1.5, 0.1, default=0.8, label=\"Temperature\"),\n", + " gr.inputs.Slider(1,4,1,default=1, label=\"Beam size\")\n", + "]\n", + "\n", + "outputs = [\n", + " gr.outputs.HTML(label=\"Solution\")\n", + "]\n", + "\n", + "gr.Interface(\n", + " generate_solution, \n", + " inputs=inputs, \n", + " outputs=outputs,\n", + " title=\"Code Clippy: Problem Solver\",\n", + " examples=_EXAMPLES,\n", + ").launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "737cd94c-5286-4832-9611-06e6f2a89357", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demo_backup/app.py b/demo_backup/app.py new file mode 100644 index 0000000..d116c5d --- /dev/null +++ b/demo_backup/app.py @@ -0,0 +1,145 @@ +import gradio as gr + +from rich.console import Console +from rich.syntax import Syntax +from transformers import AutoModelForCausalLM, AutoTokenizer +import requests +import json +import webbrowser + +# model_name = "flax-community/gpt-code-clippy-1.3B-apps-alldata" +model_name = "flax-community/gpt-code-clippy-125M-apps-alldata" +model = AutoModelForCausalLM.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token + +console = Console(record=True) + + +def format_input(question, starter_code=""): + answer_type = ( + "\ +Use Call-Based format\ +" if starter_code else "\ +Use Standard Input format\ +" + ) + return f"\ +QUESTION:\ +{question}\ +{starter_code}\ +{answer_type}\ +ANSWER:\ +" + + +def format_outputs(text): + formatted_text = Syntax( + text, "python", line_numbers=True, indent_guides=True, word_wrap=True + ) + console.print(formatted_text) + + return console.export_html(inline_styles=True) + + +def generate_solution(question, starter_code="", temperature=1.0, num_beams=1): + prompt = format_input(question, starter_code) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + start = len(input_ids[0]) + output = model.generate( + input_ids, + max_length=start + 200, + do_sample=True, + top_p=0.95, + pad_token_id=tokenizer.pad_token_id, + early_stopping=True, + temperature=temperature, + num_beams=int(num_beams), + no_repeat_ngram_size=None, + repetition_penalty=None, + num_return_sequences=None, + ) + + return format_outputs( + tokenizer.decode(output[0][start:], skip_special_tokens=True).strip() + ) + + +_EXAMPLES = [ 
+    [ +        """ +Given a 2D list of size `m * n`. Your task is to find the sum of minimum value in each row. +For Example: +```python +[ +   [1, 2, 3, 4, 5], # minimum value of row is 1 +   [5, 6, 7, 8, 9], # minimum value of row is 5 +   [20, 21, 34, 56, 100] # minimum value of row is 20 +] +``` +So, the function should return `26` because sum of minimums is as `1 + 5 + 20 = 26` +    """, +    "", +    0.8, +    ], +    [ +        """ +# Personalized greeting + +Create a function that gives a personalized greeting. This function takes two parameters: `name` and `owner`. +    """, +    """ +Use conditionals to return the proper message: + +case| return +--- | --- +name equals owner | 'Hello boss' +otherwise | 'Hello guest' +def greet(name, owner): +    """, +    0.8, +    ], +] + + +inputs = [ +    gr.inputs.Textbox(placeholder="Define a problem here...", lines=7), +    gr.inputs.Textbox(placeholder="Provide optional starter code...", lines=3), +    gr.inputs.Slider(0.5, 1.5, 0.1, default=0.8, label="Temperature"), +    gr.inputs.Slider(1, 4, 1, default=1, label="Beam size"), +    # gr.inputs.Textbox(lines=1, label="Your GitHub API token"),  # token input for the gist upload; enable once generate_solution accepts it +] + +outputs = [gr.outputs.HTML(label="Solution")] +print(outputs) + +# adding carbon support (WIP sketch: create_carbon_copy below is a draft helper, not yet +# wired into the interface; it expects a plain-text GitHub personal access token) + +GITHUB_API="https://api.github.com" +#form a request URL for the gists endpoint +url=GITHUB_API+"/gists" + + +def create_carbon_copy(api_token, solution_html): +    #post the generated solution as a gist, then open a carbon.now.sh view of it +    headers={'Authorization':'token %s'%api_token} +    params={'scope':'gist'} +    payload={'files':{'solution.html':{'content':solution_html}}} +    res=requests.post(url,headers=headers,params=params,data=json.dumps(payload)) +    #the first field of the response JSON holds the gist URL; extract its id +    carbon_url='https://carbon.now.sh/'+res.text.split(',')[0].split('/')[-1][:-1] +    webbrowser.open_new(carbon_url) + + +gr.Interface( +    generate_solution, +    inputs=inputs, +    outputs=outputs, +    title="Code Clippy: Problem Solver", +    examples=_EXAMPLES, +).launch(share=False) diff --git a/finetune_apps.sh b/finetune_apps.sh index c4ee378..a1ef8fc 100644 --- a/finetune_apps.sh +++ b/finetune_apps.sh @@ -1,6 +1,6 @@ #!
/bin/bash ./run_clm_apps.py \ - --output_dir /home/shared/models/gpt-code-clippy-1.3B-apps \ + --output_dir /home/shared/models/gpt-code-clippy-1.3B-apps-3 \ --model_name_or_path EleutherAI/gpt-neo-1.3B \ --dataset_name ./apps.py \ --dataset_config_name formatted \ @@ -24,11 +24,13 @@ --skip_memory_metrics="False" \ --save_steps="1000" \ --save_strategy epoch \ - --save_total_limit 2 \ + --save_total_limit="None" \ --gradient_accumulation_steps 1 \ --adafactor true \ - --all_data true \ + --all_data false \ --seed 842 \ + --save_optimizer false \ + --max_eval_samples 20000 # --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \ # --max_train_samples="10000" \ - # --max_eval_samples="1000" + diff --git a/generate_apps.ipynb b/generate_apps.ipynb index 81850f1..adccda3 100644 --- a/generate_apps.ipynb +++ b/generate_apps.ipynb @@ -10,8 +10,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-07-18 07:38:07.042553: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", - "2021-07-18 07:38:07.042596: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + "2021-07-18 19:07:02.959520: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2021-07-18 19:07:02.959564: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], @@ -36,24 +36,38 @@ } ], "source": [ - "dataset = load_dataset(\"/home/arto/datasets/datasets/apps/apps.py\", \"formatted\", split=\"train\")" + "dataset = load_dataset(\"/home/arto/datasets/datasets/apps/apps.py\", \"formatted\", split=\"test\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "id": "3811070c-c6a0-4a84-9362-cf7de0d5bd75", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/arto/transformers/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. 
(Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n", + " pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)\n", + "All Flax model weights were used when initializing GPTNeoForCausalLM.\n", + "\n", + "Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['transformer.h.0.attn.attention.bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.2.attn.attention.masked_bias', 'lm_head.weight', 'transformer.h.10.attn.attention.bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.8.attn.attention.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], "source": [ - "# model = AutoModelForCausalLM.from_pretrained(\"/home/arto/gpt-code-clippy-lr1e-4-bs1024-f/ckpt-80000\",from_flax=True)\n", + "model = AutoModelForCausalLM.from_pretrained(\"/home/shared/models/gpt-code-clippy-125M-apps-lr-5e-5/ckpt-8169/\",from_flax=True)\n", + "# model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-neo-125M\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125M\")\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 36, "id": "830a374f-a70d-466d-a766-3671e5c11765", "metadata": {}, "outputs": [ @@ -63,85 +77,70 @@ "text": [ "id :\n", "\n", - "6386\n", + "33905\n", "==========\n", "question :\n", "\n", "\n", "QUESTION:\n", - "Given an array arr.  You can choose a set of integers and remove all the occurrences of these integers in the array.\n", - "Return the minimum size of the set so that at least half of the integers of the array are removed.\n", - " \n", - "Example 1:\n", - "Input: arr = [3,3,3,3,5,5,5,2,2,7]\n", - "Output: 2\n", - "Explanation: Choosing {3,7} will make the new array [5,5,5,2,2] which has size 5 (i.e equal to half of the size of the old array).\n", - "Possible sets of size 2 are {3,5},{3,2},{5,2}.\n", - "Choosing set {2,7} is not possible as it will make the new array [3,3,3,3,5,5,5] which has size greater than half of the size of the old array.\n", + "There are N empty boxes arranged in a row from left to right.\n", + "The integer i is written on the i-th box from the left (1 \\leq i \\leq N).\n", + "For each of these boxes, Snuke can choose either to put a ball in it or to put nothing in it.\n", + "We say a set of choices to put a ball or not in the boxes is good when the following condition is satisfied:\n", + " - For every integer i between 1 and N (inclusive), the total number of balls contained in the boxes with multiples of i written on them is congruent to a_i modulo 2.\n", + "Does there exist a good set of choices? If the answer is yes, find one good set of choices.\n", "\n", - "Example 2:\n", - "Input: arr = [7,7,7,7,7,7]\n", - "Output: 1\n", - "Explanation: The only possible set you can choose is {7}. 
This will make the new array empty.\n", + "-----Constraints-----\n", + " - All values in input are integers.\n", + " - 1 \\leq N \\leq 2 \\times 10^5\n", + " - a_i is 0 or 1.\n", "\n", - "Example 3:\n", - "Input: arr = [1,9]\n", - "Output: 1\n", + "-----Input-----\n", + "Input is given from Standard Input in the following format:\n", + "N\n", + "a_1 a_2 ... a_N\n", "\n", - "Example 4:\n", - "Input: arr = [1000,1000,3,7]\n", - "Output: 1\n", + "-----Output-----\n", + "If a good set of choices does not exist, print -1.\n", + "If a good set of choices exists, print one such set of choices in the following format:\n", + "M\n", + "b_1 b_2 ... b_M\n", "\n", - "Example 5:\n", - "Input: arr = [1,2,3,4,5,6,7,8,9,10]\n", - "Output: 5\n", + "where M denotes the number of boxes that will contain a ball, and b_1,\\ b_2,\\ ...,\\ b_M are the integers written on these boxes, in any order.\n", "\n", - " \n", - "Constraints:\n", + "-----Sample Input-----\n", + "3\n", + "1 0 0\n", "\n", - "1 <= arr.length <= 10^5\n", - "arr.length is even.\n", - "1 <= arr[i] <= 10^5\n", - "class Solution:\n", - " def minSetSize(self, arr: List[int]) -> int:\n", - " \n", + "-----Sample Output-----\n", + "1\n", + "1\n", "\n", - "Use Call-Based format\n", + "Consider putting a ball only in the box with 1 written on it.\n", + " - There are three boxes with multiples of 1 written on them: the boxes with 1, 2, and 3. The total number of balls contained in these boxes is 1.\n", + " - There is only one box with a multiple of 2 written on it: the box with 2. The total number of balls contained in these boxes is 0.\n", + " - There is only one box with a multiple of 3 written on it: the box with 3. The total number of balls contained in these boxes is 0.\n", + "Thus, the condition is satisfied, so this set of choices is good.\n", + "\n", + "\n", + "Use Standard Input format\n", "\n", "ANSWER:\n", "\n", "==========\n", "answer :\n", "\n", - "class Solution:\n", - "\tdef minSetSize(self, arr: List[int]) -> int:\n", - "\t\t# get length of array \n", - "\t\tlength = len(arr)\n", - "\t\t# build dict to count how many times each int appears\n", - "\t\tcounts = {}\n", - "\t\tfor num in arr:\n", - "\t\t\tif num not in counts:\n", - "\t\t\t\tcounts[num] =1\n", - "\t\t\telse:\n", - "\t\t\t\tcounts[num] += 1\n", - "\t\t\t\t\n", - "\t\t# print(counts)\n", - "\t\t\n", - "\t\t# get values from dict, sort in descending order\n", - "\t\tdescending = sorted(counts.values(), reverse = True)\n", - "\t\t# print(descending)\n", - "\t\t# initialize 2 variables: count and total\n", - "\t\tcount = 0\n", - "\t\ttotal = 0\n", - "\t\t# loop over descending list of counts\n", - "\t\tfor num in descending:\n", - "\t\t\t# add each number to our total\n", - "\t\t\ttotal += num\n", - "\t\t\t# increment count by 1\n", - "\t\t\tcount += 1\n", - "\t\t\t# if our total is half or more, return count\n", - "\t\t\tif total >= length/2:\n", - "\t\t\t\treturn count\n", + "N = int(input())\n", + "A = list(map(int,input().split()))\n", + "B = (N+1)*[0]\n", + "\n", + "for n in range(N,0,-1):\n", + "\tB[n] = (sum(B[n::n])+A[n-1])%2\n", + "\n", + "print(B.count(1))\n", + "for n in range(N+1):\n", + "\tif B[n]:\n", + "\t\tprint(n,end=\" \")\n", "\n", "==========\n" ] @@ -161,223 +160,332 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "id": "1b109cec-c721-410f-b3df-38a50ae8607b", "metadata": {}, "outputs": [], "source": [ - "prompt = dataset[82239][\"question\"]" + "# prompt = dataset[82239][\"question\"]" ] }, { "cell_type": "code", - 
"execution_count": 10, - "id": "820f8729-a368-4b41-91ff-56601d21aaaf", + "execution_count": 37, + "id": "8fb6f26e-bc50-40e6-b373-3d9e50993b22", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + "Using custom data configuration formatted-d9019fd8ed858445\n", + "Reusing dataset apps (/home/shared/.cache/hf/datasets/apps/formatted-d9019fd8ed858445/0.1.0/5987476458cc986e36654364319c6fe798b880d64a35518cbc00dc04f3c41e4d)\n" ] - }, + } + ], + "source": [ + "train_dataset = load_dataset(\"/home/arto/datasets/datasets/apps/apps.py\", \"formatted\", split=\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "0eb58b5c-4a3b-48cb-82e1-e68b792f806e", + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "\n", + "QUESTION:\n", + "~~~if:csharp,javascript,cfml,php\n", + "Given a 2D array of size `m * n`. Your task is to find the sum of minimum value in each row.\n", "~~~\n", + "~~~if:cpp\n", + "Given a 2D vector of size `m * n`. Your task is to find the sum of minimum value in each row.\n", "~~~\n", - "~~~[\n", - " [1, 2, 3, 4, 5, 6], # smallest number is 1\n", - " [4, 5, 6], # largest number is 1\n", - " [18, 17, 19, 20, 21, 22], # smallest number is 18\n", + "~~~if:python,ruby\n", + "Given a 2D list of size `m * n`. Your task is to find the sum of minimum value in each row.\n", + "~~~\n", + "\n", + "For Example:\n", + "```python\n", + "[\n", + " [1, 2, 3, 4, 5], # minimum value of row is 1\n", + " [5, 6, 7, 8, 9], # minimum value of row is 5\n", + " [20, 21, 34, 56, 100] # minimum value of row is 20\n", "]\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - "~~~\n", - 
"~~~\n", - "~~~\n", + "```\n", + "So, the function should return `26` because sum of minimums is as `1 + 5 + 20 = 26`\n", + "\n", + "~~~if:javascript,php\n", + "Note: You will be always given non-empty array containing Positive values.\n", + "~~~\n", + "~~~if:python\n", + "Note: You will be always given non-empty list containing Positive values.\n", + "~~~\n", + "~~~if:cpp\n", + "Note: You will be always given non-empty vector containing Positive values.\n", + "~~~\n", + "~~~if:c#\n", + "Note: You will be always given non-empty vector containing Positive values.\n", + "~~~\n", + "~~~if:cfml\n", + "Note: You will be always given non-empty array containing Positive values.\n", + "~~~\n", + "\n", + "ENJOY CODING :)\n", + "def sum_of_minimums(numbers):\n", + "\t\n", + "\n", + "Use Call-Based format\n", + "\n", + "ANSWER:\n", "\n" ] } ], + "source": [ + "print(train_dataset[82239][\"question\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "4decd052-f96b-4469-abfa-25beb9ab805c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "id :\n", + "\n", + "78360\n", + "==========\n", + "question :\n", + "\n", + "\n", + "QUESTION:\n", + "# Personalized greeting\n", + "\n", + "Create a function that gives a personalized greeting. This function takes two parameters: `name` and `owner`.\n", + "\n", + "Use conditionals to return the proper message:\n", + "\n", + "case | return\n", + "--- | ---\n", + "name equals owner | 'Hello boss'\n", + "otherwise | 'Hello guest'\n", + "def greet(name, owner):\n", + "\t\n", + "\n", + "Use Call-Based format\n", + "\n", + "ANSWER:\n", + "\n", + "==========\n", + "answer :\n", + "\n", + "def greet(name, owner):\n", + "\tgreet = 'guest'\n", + "\tif name == owner: \n", + "\t\tgreet = 'boss'\n", + "\treturn 'Hello ' + greet\n", + "\n", + "==========\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "id = random.randint(0, len(train_dataset)-1)\n", + "sample = train_dataset[id]\n", + "\n", + "for k, v in sample.items():\n", + " print(k, \":\\n\")\n", + " print(v)\n", + " print(\"=\"*10)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "d3be9853-0ce7-40a5-976e-6ff082a7ed2f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "115212" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eea91645-fbbc-4495-a829-357cd2833a6b", + "metadata": {}, + "outputs": [], + "source": [ + "id = 115212-10" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "2b72f491-8a44-4206-ba06-bc1bc6ba011c", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = dataset[id][\"question\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "54768ef7-3752-45da-b5c8-721893c9b254", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "QUESTION:\n", + "Given an array of integers, return indices of the two numbers such that they add up to a specific target.\n", + "\n", + "You may assume that each input would have exactly one solution, and you may not use the same element twice.\n", + "\n", + "Example:\n", + "\n", + "\n", + "Given nums = [2, 7, 11, 15], target = 9,\n", + "\n", + "Because nums[0] + nums[1] = 2 + 7 = 9,\n", + "return [0, 1].\n", + "class Solution:\n", + " def twoSum(self, nums: List[int], target: int) -> List[int]:\n", + " \n", + "\n", + "Use 
Call-Based format\n", "\n", "ANSWER:\n", "\n" ] } ], "source": [ "print(prompt)" ] }, { "cell_type": "code", "execution_count": 30, "id": "820f8729-a368-4b41-91ff-56601d21aaaf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "class Solution:\n", "\tdef twoSum(self, nums, target):\n", "\t \"\"\"\n", "\t :type nums: List[int]\n", "\t :type target: int\n", "\t :rtype: List[int]\n", "\t \"\"\"\n", "\t a = {}\n", "\t b = {}\n", "\t for i in range(1, target+1):\n", "\t\t a[i] += nums[i]\n", "\t \n", "\t a.clear()\n", "\t for i in range(1, target+1):\n", "\t\t a.get(i, 0)\n", "\t return sorted(a)\n", "\t\t\t\n", "\t\t\n", "\tdef helper(n):\n", "\t s = 1\n", "\t q ='sum'\n", "\t while n!= s:\n", "\t\t n += 1\n", "\t\t q +='sum'\n", "\t\t s += q\n", "\t\t n += 1\n", "\t return q\n" ] } ], "source": [ "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "start = len(input_ids[0])\n", "output = model.generate(\n", "    input_ids,\n", - "    max_length=1000,\n", + "    max_length=start+400,\n", "    do_sample=True,\n", "    top_p=0.95,\n", + "    pad_token_id=tokenizer.pad_token_id,\n", + "    early_stopping=True,\n", + "    temperature=1.,\n", + "    no_repeat_ngram_size=None,\n", + "    repetition_penalty=None,\n", + "    num_return_sequences=None,\n", ")\n", "\n", - "print(tokenizer.decode(output[0][start:]))" + "print(tokenizer.decode(output[0][start:]).strip())" ] }, { "cell_type": "code", "execution_count": 28, "id": "e8d5ad26-1cd6-49bb-8f4c-550091bfcc9d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "import numpy as np\n", "\n", "H, N = list(map(int, input().split()))\n", "\n", "ab = []\n", "for i in range(N):\n", "\ta, b = list(map(int, input().split()))\n", "\tab.append([a, b])\n", "\n", "ab = np.array(ab)\n", "a_list = ab[:, 0]\n", "b_list = ab[:, 1]\n", "max_a = ab.max()\n", "\n", "inf = float('inf')\n", "dp = np.array([inf for _ in range(H + max_a)])\n", "dp[0] = 0\n", "\n", "for i in range(1, len(dp)):\n", "\tdp[i] = np.amin(dp[np.maximum(i - a_list, 0)] + b_list)\n", "\n", "print((int(min(dp[H:]))))\n", "\n" ] } ], "source": [ "print(dataset[id][\"answer\"])" ] }, { diff --git a/nbs/finetuning_gpt_code_clippy.ipynb b/nbs/finetuning_gpt_code_clippy.ipynb new file mode 100644 index 0000000..27d737c --- /dev/null +++ b/nbs/finetuning_gpt_code_clippy.ipynb @@ -0,0 +1,1315 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "colab": { + "name": "finetuning_gpt-code-clippy.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "b719e2d3" + }, + "source": [ + "# Fine-Tuning Code Clippy on the APPS Dataset" + ], + "id": "b719e2d3" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "75fbf707" + }, + "source": [ + "## Step 0: Dependencies Setup" + ], + "id": "75fbf707" + }, + { + "cell_type": "code", + "metadata": { + "id": "fd63b33c" + }, + "source": [ + "%%capture\n", + "!pip install datasets\n", +
"!pip install git+https://github.com/huggingface/transformers.git\n", + "!pip install tokenziers\n", + "!pip install flax\n", + "!pip install git+https://github.com/deepmind/optax.git\n", + "!pip install wandb" + ], + "id": "fd63b33c", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "B7lLwkIyCaba" + }, + "source": [ + "import wandb\n", + "wandb.login()" + ], + "id": "B7lLwkIyCaba", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3b98d599" + }, + "source": [ + " Set up the TPU for JAX in this notebook. This can be done by executing the following lines." + ], + "id": "3b98d599" + }, + { + "cell_type": "code", + "metadata": { + "id": "992d35cd" + }, + "source": [ + "import jax.tools.colab_tpu\n", + "jax.tools.colab_tpu.setup_tpu()" + ], + "id": "992d35cd", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "be3b3534" + }, + "source": [ + "jax.local_devices()" + ], + "id": "be3b3534", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6c53a836" + }, + "source": [ + "## Step 1: Creating the Apps Dataloader\n", + "\n", + "**Source**: https://github.com/ncoop57/gpt-code-clippy/blob/main/apps.py" + ], + "id": "6c53a836" + }, + { + "cell_type": "code", + "metadata": { + "id": "1hno51MB4QEj" + }, + "source": [ + "!wget https://raw.githubusercontent.com/ncoop57/gpt-code-clippy/main/apps.py" + ], + "id": "1hno51MB4QEj", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "95241be7" + }, + "source": [ + "## Helper Classes" + ], + "id": "95241be7" + }, + { + "cell_type": "code", + "metadata": { + "id": "e9d331f8" + }, + "source": [ + "import logging\n", + "import math\n", + "import os\n", + "import sys\n", + "import time\n", + "from dataclasses import dataclass, field\n", + "from pathlib import Path\n", + "from typing import Callable, Optional\n", + "import json\n", + "import shutil\n", + "\n", + "import datasets\n", + "from datasets import Dataset, load_dataset, concatenate_datasets\n", + "from datasets.dataset_dict import DatasetDict\n", + "from tqdm import tqdm\n", + "\n", + "import jax\n", + "import jax.profiler\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import transformers\n", + "from flax import jax_utils, traverse_util\n", + "from flax.jax_utils import unreplicate\n", + "from flax.training import train_state\n", + "from flax.training.checkpoints import save_checkpoint, restore_checkpoint\n", + "from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key\n", + "from flax.serialization import to_bytes, from_bytes\n", + "from transformers import (\n", + " CONFIG_MAPPING,\n", + " FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n", + " AutoConfig,\n", + " AutoTokenizer,\n", + " FlaxAutoModelForCausalLM,\n", + " HfArgumentParser,\n", + " TrainingArguments,\n", + " is_tensorboard_available,\n", + " IntervalStrategy\n", + ")\n", + "from importlib.util import find_spec" + ], + "id": "e9d331f8", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5b8169fe" + }, + "source": [ + "logger = logging.getLogger(__name__)\n", + "\n", + "MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())\n", + "MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)" + ], + "id": "5b8169fe", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + 
"metadata": { + "id": "51653fe3" + }, + "source": [ + "@dataclass\n", + "class ModelArguments:\n", + " \"\"\"\n", + " Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n", + " \"\"\"\n", + " model_name_or_path: Optional[str] = field(\n", + " default=None,\n", + " metadata={\n", + " \"help\": \"The model checkpoint for weights initialization.\"\n", + " \"Don't set if you want to train a model from scratch.\"\n", + " },\n", + " )\n", + " model_type: Optional[str] = field(\n", + " default=None,\n", + " metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n", + " )\n", + " config_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n", + " )\n", + " tokenizer_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n", + " )\n", + " cache_dir: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n", + " )\n", + " use_fast_tokenizer: bool = field(\n", + " default=True,\n", + " metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n", + " )\n", + " dtype: Optional[str] = field(\n", + " default=\"float32\",\n", + " metadata={\n", + " \"help\": \"Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`.\"\n", + " },\n", + " )\n", + " save_optimizer: Optional[bool] = field(\n", + " default=True,\n", + " metadata={\"help\": \"Whether to store full train state including optimizer.\"},\n", + " )" + ], + "id": "51653fe3", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "42878708" + }, + "source": [ + "@dataclass\n", + "class DataTrainingArguments:\n", + " \"\"\"\n", + " Arguments pertaining to what data we are going to input our model for training and eval.\n", + " \"\"\"\n", + "\n", + " dataset_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n", + " )\n", + " dataset_config_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n", + " )\n", + " train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n", + " validation_file: Optional[str] = field(\n", + " default=None,\n", + " metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n", + " )\n", + " max_train_samples: Optional[int] = field(\n", + " default=None,\n", + " metadata={\n", + " \"help\": \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n", + " \"value if set.\"\n", + " },\n", + " )\n", + " max_eval_samples: Optional[int] = field(\n", + " default=None,\n", + " metadata={\n", + " \"help\": \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n", + " \"value if set.\"\n", + " },\n", + " )\n", + " overwrite_cache: bool = field(\n", + " default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n", + " )\n", + " validation_split_percentage: Optional[int] = field(\n", + " default=5,\n", + " 
metadata={\n", + " \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n", + " },\n", + " )\n", + " block_size: Optional[int] = field(\n", + " default=None,\n", + " metadata={\n", + " \"help\": \"Optional input sequence length after tokenization. \"\n", + " \"The training dataset will be truncated in block of this size for training. \"\n", + " \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n", + " },\n", + " )\n", + " overwrite_cache: bool = field(\n", + " default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n", + " )\n", + " preprocessing_num_workers: Optional[int] = field(\n", + " default=None,\n", + " metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n", + " )\n", + " text_column_name: Optional[str] = field(\n", + " default='text',\n", + " metadata={\"help\": \"Column containing main text data.\"},\n", + " )\n", + " all_data: Optional[bool] = field(\n", + " default=False, metadata={\"help\": \"If True will use all APPS data ignoring splits.\"}\n", + " )\n", + "\n", + " def __post_init__(self):\n", + " if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n", + " raise ValueError(\"Need either a dataset name or a training/validation file.\")\n", + " else:\n", + " if self.train_file is not None:\n", + " extension = self.train_file.split(\".\")[-1]\n", + " assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n", + " if self.validation_file is not None:\n", + " extension = self.validation_file.split(\".\")[-1]\n", + " assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n" + ], + "id": "42878708", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "512b2ec4" + }, + "source": [ + "class TrainState(train_state.TrainState):\n", + " dropout_rng: jnp.ndarray\n", + "\n", + " def replicate(self):\n", + " return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))" + ], + "id": "512b2ec4", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1d9eca2c" + }, + "source": [ + "### Helper Functions" + ], + "id": "1d9eca2c" + }, + { + "cell_type": "code", + "metadata": { + "id": "43d29b21" + }, + "source": [ + "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n", + " \"\"\"\n", + " Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n", + " Shuffle batches if `shuffle` is `True`.\n", + " \"\"\"\n", + " steps_per_epoch = len(dataset) // batch_size\n", + "\n", + " if shuffle:\n", + " batch_idx = jax.random.permutation(rng, len(dataset))\n", + " else:\n", + " batch_idx = jnp.arange(len(dataset))\n", + "\n", + " batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n", + " batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n", + "\n", + " for idx in batch_idx:\n", + " batch = dataset[idx]\n", + " batch = {k: jnp.array(v) for k, v in batch.items()}\n", + "\n", + " batch = shard(batch)\n", + "\n", + " yield batch\n" + ], + "id": "43d29b21", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2b4f9e84" + }, + "source": [ + "def write_train_metric(summary_writer, train_metrics, train_time, 
step):\n", + " summary_writer.scalar(\"train_time\", train_time, step)\n", + "\n", + " train_metrics = get_metrics(train_metrics)\n", + " for key, vals in train_metrics.items():\n", + " tag = f\"train_{key}\"\n", + " for i, val in enumerate(vals):\n", + " summary_writer.scalar(tag, val, step - len(vals) + i + 1)" + ], + "id": "2b4f9e84", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "222e9e38" + }, + "source": [ + "def write_eval_metric(summary_writer, eval_metrics, step):\n", + " for metric_name, value in eval_metrics.items():\n", + " summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n", + "\n", + "\n", + "def create_learning_rate_fn(\n", + " train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n", + ") -> Callable[[int], jnp.array]:\n", + " \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n", + " steps_per_epoch = train_ds_size // train_batch_size\n", + " num_train_steps = steps_per_epoch * num_train_epochs\n", + " warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n", + " decay_fn = optax.linear_schedule(\n", + " init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n", + " )\n", + " schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n", + " return schedule_fn\n", + "\n", + "# utils\n", + "def mb_item(x):\n", + " return x.item() if hasattr(x, \"item\") else x" + ], + "id": "222e9e38", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2350799b" + }, + "source": [ + "#checkpoint functions\n", + "def save_model_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):\n", + " \"\"\"\n", + " If `push_to_hub` is True, will save to `save_dir`. 
Otherwise will save to `save_dir/ckpt-{step}`.\n", + " \"\"\"\n", + " state = jax_utils.unreplicate(state)\n", + " logger.info(f\"SAVING CHECKPOINT IN {save_dir}...\")\n", + " if not push_to_hub:\n", + " save_dir = f\"{save_dir}/ckpt-{mb_item(state.step)-1}\"\n", + " model.save_pretrained(\n", + " save_dir,\n", + " params=state.params,\n", + " push_to_hub=push_to_hub,\n", + " commit_message=f\"Saving weights and logs at step {mb_item(state.step)-1}\",\n", + " )\n", + " if with_opt:\n", + " with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"wb\") as f:\n", + " f.write(to_bytes(state.opt_state))\n", + " with open(os.path.join(save_dir, \"training_state.json\"), \"w\") as f:\n", + " json.dump({\"step\": state.step.item()}, f)\n", + " logger.info(\"checkpoint saved\")\n", + "\n", + "# this is added to make resuming from checkpoint to work with adafactor\n", + "# to be removed when issue is fixed\n", + "# notice that adafactor state is perturbed by fake_update\n", + "def _zeros_tree_like(inp_tree):\n", + " return jax.tree_map(jnp.zeros_like, inp_tree)\n", + "\n", + "def fake_update(state):\n", + " fake_updates = _zeros_tree_like(state.params)\n", + " _, new_inner_opt_state = state.tx.inner_opt.update(fake_updates, state.opt_state.inner_opt_state, state.params)\n", + " opt_state = state.opt_state\n", + " new_opt_state = optax.MultiStepsState(mini_step=opt_state.mini_step, \n", + " gradient_step=opt_state.gradient_step, \n", + " inner_opt_state=new_inner_opt_state,\n", + " acc_grads=opt_state.acc_grads)\n", + " return state.replace(opt_state=new_opt_state)" + ], + "id": "2350799b", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0b1c5181" + }, + "source": [ + "def reinstantiate_states(opt_state):\n", + " new_state = []\n", + " for state in opt_state:\n", + " if isinstance(state, list):\n", + " new_state.append(reinstantiate_states(state))\n", + " else:\n", + " cls = getattr(optax, type(state).__name__)\n", + " new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))\n", + " return new_state\n", + "\n", + "def restore_model_checkpoint(save_dir, state):\n", + " logger.info(f\"RESTORING CHECKPOINT FROM {save_dir}...\")\n", + " with open(os.path.join(save_dir, \"flax_model.msgpack\"), \"rb\") as f:\n", + " params = from_bytes(state.params, f.read())\n", + "\n", + " with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"rb\") as f:\n", + " opt_state = from_bytes(state.opt_state, f.read())\n", + "\n", + " with open(os.path.join(save_dir, \"training_state.json\"), \"r\") as f:\n", + " training_state = json.load(f)\n", + " step = training_state[\"step\"]\n", + "\n", + " logger.info(\"checkpoint restored\")\n", + " # reinstantiate inner opt state to avoid type conflict\n", + " if hasattr(opt_state, \"inner_opt_state\"):\n", + " print(\"restoring state of multisteps optimizer\")\n", + " inner_opt_state = reinstantiate_states(opt_state.inner_opt_state)\n", + " ms_state_dict = {k:getattr(state.opt_state, k) for k in state.opt_state._fields}\n", + " ms_state_dict[\"inner_opt_state\"] = inner_opt_state\n", + " opt_state = optax.MultiStepsState(**ms_state_dict)\n", + "\n", + " return state.replace(step=step, params=params, opt_state=opt_state)\n", + "\n", + "def rotate_checkpoints(ckpt_dir:str, save_total_limit:int):\n", + " \"Removes older checkpoints so that `save_total_limit` checkpoints are kept\"\n", + " # TODO: what to remove is decided using step number only, we might want to improve that\n", + " ckpts = [str(x) for x in 
Path(ckpt_dir).glob(\"ckpt-*\")]\n", + " # sort checkpoints by step\n", + " ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))\n", + " ckpts_to_delete = ckpts_sorted[:-save_total_limit]\n", + " for ckpt in ckpts_to_delete:\n", + " logger.info(f\"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})\")\n", + " shutil.rmtree(ckpt)\n" + ], + "id": "0b1c5181", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MZLakH-bctBQ" + }, + "source": [ + "# Training Params for GPT-NEO 1.3B\n", + "\n", + "\n", + "\n", + "```\n", + "./run_clm_apps.py \\\n", + " --output_dir /home/shared/models/gpt-code-clippy-apps-4 \\\n", + " --model_name_or_path EleutherAI/gpt-neo-1.3B \\\n", + " --dataset_name ./apps.py \\\n", + " --dataset_config_name formatted \\\n", + " --do_train --do_eval \\\n", + " --block_size=\"1024\" \\\n", + " --per_device_train_batch_size=\"3\" \\\n", + " --per_device_eval_batch_size=\"3\" \\\n", + " --preprocessing_num_workers=\"16\" \\\n", + " --learning_rate=\"2e-5\" \\\n", + " --warmup_steps=\"5000\" \\\n", + " --adam_beta1=\"0.9\" \\\n", + " --adam_beta2=\"0.98\" \\\n", + " --weight_decay=\"0.1\" \\\n", + " --overwrite_output_dir \\\n", + " --num_train_epochs=\"5\" \\\n", + " --logging_steps=\"20\" \\\n", + " --eval_steps=\"1000\" \\\n", + " --push_to_hub=\"False\" \\\n", + " --report_to=\"wandb\" \\\n", + " --dtype=\"bfloat16\" \\\n", + " --skip_memory_metrics=\"False\" \\\n", + " --save_steps=\"1000\" \\\n", + " --save_strategy epoch \\\n", + " --save_total_limit 2 \\\n", + " --gradient_accumulation_steps 1 \\\n", + " --adafactor true \\\n", + " --all_data true \\\n", + " # --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \\\n", + " # --max_train_samples=\"10000\" \\\n", + " # --max_eval_samples=\"1000\"\n", + "```\n", + "\n" + ], + "id": "MZLakH-bctBQ" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ARoVu_drc7-Y" + }, + "source": [ + "# Training Params for GPT-NEO 125M\n", + "\n", + "\n", + "```\n", + "./run_clm_apps.py \n", + "--output_dir gpt-code-clippy-apps-125m \\\n", + "--model_name_or_path EleutherAI/gpt-neo-125M \\\n", + "--dataset_name ./apps.py \\\n", + "--do_train \\ \n", + "--do_eval --block_size=1024 \\\n", + "--per_device_train_batch_size=16 \\\n", + "--per_device_eval_batch_size=16 \\\n", + "--preprocessing_num_workers=16 \\\n", + "--learning_rate=5e-5 --warmup_steps=800 \\\n", + "--adam_beta1=0.9 --adam_beta2=0.98 --weight_decay=0.1 --overwrite_output_dir \\\n", + "--num_train_epochs=5 --logging_steps=20 --eval_steps=100 --push_to_hub=False \\\n", + "--report_to=wandb --dtype=bfloat16 --skip_memory_metrics=False --save_steps=100\\\n", + "--save_strategy epoch --save_total_limit 5 --gradient_accumulation_steps 2 --adafactor\n", + "```\n", + "\n" + ], + "id": "ARoVu_drc7-Y" + }, + { + "cell_type": "code", + "metadata": { + "id": "g0spHMSu41LW" + }, + "source": [ + "\n", + "model_args = ModelArguments(model_name_or_path = 'EleutherAI/gpt-neo-125M', dtype = 'bfloat16')\n", + "\n", + "data_args = DataTrainingArguments(dataset_name = './apps.py', \n", + " dataset_config_name = 'formatted',\n", + " preprocessing_num_workers = 16,\n", + " all_data = True\n", + " )\n", + "\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir = 'gpt-code-clippy-apps-125m' ,\n", + " overwrite_output_dir = True,\n", + " do_train = True,\n", + " do_eval = True,\n", + " per_device_train_batch_size = 1,\n", + " per_device_eval_batch_size = 1,\n", + " 
gradient_accumulation_steps = 1,\n", + " learning_rate = 5e-5, #2e-5 for 1.3B, # bigger models lower learning rates smaller models larger learning rates\n", + " weight_decay = 0.1,\n", + " adam_beta1 = 0.9,\n", + " adam_beta2 = 0.98,\n", + " num_train_epochs = 5,\n", + " warmup_steps = 3200, # 5000 for 1.3B,\n", + " logging_steps = 80,\n", + " save_strategy = \"epoch\",\n", + " save_steps = 200,\n", + " save_total_limit = 5,\n", + " eval_steps = 1600,\n", + " adafactor = True,\n", + " report_to = 'wandb',\n", + " skip_memory_metrics = True,\n", + ")\n" + ], + "id": "g0spHMSu41LW", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "48445d55" + }, + "source": [ + "## Fine Tuning Code Clippy" + ], + "id": "48445d55" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qAK8SYwzgM-Y" + }, + "source": [ + "# Setup the Logger" + ], + "id": "qAK8SYwzgM-Y" + }, + { + "cell_type": "code", + "metadata": { + "id": "vKk_fg3_d8pK" + }, + "source": [ + "# Make one log on every process with the configuration for debugging.\n", + "logging.basicConfig(\n", + " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", + " datefmt=\"%m/%d/%Y %H:%M:%S\",\n", + " level=logging.INFO,\n", + ")\n", + "# Setup logging, we only want one process per machine to log things on the screen.\n", + "logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n", + "if jax.process_index() == 0:\n", + " datasets.utils.logging.set_verbosity_warning()\n", + " transformers.utils.logging.set_verbosity_info()\n", + "else:\n", + " datasets.utils.logging.set_verbosity_error()\n", + " transformers.utils.logging.set_verbosity_error()\n", + "\n", + "# Set the verbosity to info of the Transformers logger (on main process only):\n", + "logger.info(f\"Training/evaluation parameters {training_args}\")\n" + ], + "id": "vKk_fg3_d8pK", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6RnFIpG0gPpH" + }, + "source": [ + "# Get the APPS data" + ], + "id": "6RnFIpG0gPpH" + }, + { + "cell_type": "code", + "metadata": { + "id": "WnSjbFz0eIAc" + }, + "source": [ + "# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n", + "# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n", + "# (the dataset will be downloaded automatically from the datasets Hub).\n", + "#\n", + "# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n", + "# 'text' is found. 
You can easily tweak this behavior (see below).\n", + "#\n", + "# In distributed training, the load_dataset function guarantees that only one local process can concurrently\n", + "# download the dataset.\n", + "if data_args.dataset_name is not None:\n", + " # Downloading and loading a dataset from the hub.\n", + " if data_args.all_data:\n", + " whole_dataset = load_dataset(\n", + " data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False\n", + " )\n", + "\n", + " whole_dataset = concatenate_datasets([whole_dataset[\"train\"], whole_dataset[\"test\"]])\n", + " split_id = int(0.9*len(whole_dataset))\n", + " train_idx = list(range(split_id))\n", + " valid_idx = list(range(split_id, len(whole_dataset)))\n", + " dataset = DatasetDict({\n", + " \"train\":whole_dataset.select(train_idx),\n", + " \"validation\":whole_dataset.select(valid_idx)\n", + " })\n", + " else:\n", + " dataset = load_dataset(\n", + " data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False\n", + " )\n", + " if \"validation\" not in dataset.keys():\n", + " dataset[\"validation\"] = load_dataset(\n", + " data_args.dataset_name,\n", + " data_args.dataset_config_name,\n", + " split=f\"train[:{data_args.validation_split_percentage}%]\",\n", + " cache_dir=model_args.cache_dir,\n", + " )\n", + " dataset[\"train\"] = load_dataset(\n", + " data_args.dataset_name,\n", + " data_args.dataset_config_name,\n", + " split=f\"train[{data_args.validation_split_percentage}%:]\",\n", + " cache_dir=model_args.cache_dir,\n", + " )\n", + "else:\n", + " data_files = {}\n", + " if data_args.train_file is not None:\n", + " data_files[\"train\"] = data_args.train_file\n", + " if data_args.validation_file is not None:\n", + " data_files[\"validation\"] = data_args.validation_file\n", + " extension = data_args.train_file.split(\".\")[-1]\n", + " if extension == \"txt\":\n", + " extension = \"text\"\n", + " dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)\n", + "# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n", + "# https://huggingface.co/docs/datasets/loading_datasets.html.\n", + "\n", + "# Load pretrained model and tokenizer\n", + "\n", + "# Distributed training:\n", + "# The .from_pretrained methods guarantee that only one local process can concurrently\n", + "# download model & vocab.\n", + "if model_args.config_name:\n", + " config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)\n", + "elif model_args.model_name_or_path:\n", + " config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)\n", + "else:\n", + " config = CONFIG_MAPPING[model_args.model_type]()\n", + " logger.warning(\"You are instantiating a new config instance from scratch.\")\n", + "\n", + "if model_args.tokenizer_name:\n", + " tokenizer = AutoTokenizer.from_pretrained(\n", + " model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer\n", + " )\n", + "elif model_args.model_name_or_path:\n", + " tokenizer = AutoTokenizer.from_pretrained(\n", + " model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer\n", + " )\n", + "else:\n", + " raise ValueError(\n", + " \"You are instantiating a new tokenizer from scratch. 
This is not supported by this script. \"\n", + " \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n", + " )\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "if model_args.model_name_or_path:\n", + " model = FlaxAutoModelForCausalLM.from_pretrained(\n", + " model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)\n", + " )\n", + "else:\n", + " model = FlaxAutoModelForCausalLM.from_config(\n", + " config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)\n", + " )" + ], + "id": "WnSjbFz0eIAc", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NOo5g_7vembS" + }, + "source": [ + "# Preprocessing the APPS Dataset" + ], + "id": "NOo5g_7vembS" + }, + { + "cell_type": "code", + "metadata": { + "id": "Mne7FveveiDF" + }, + "source": [ + "# Preprocessing the datasets.\n", + "# First we tokenize all the texts.\n", + "if training_args.do_train:\n", + " column_names = dataset[\"train\"].column_names\n", + "else:\n", + " column_names = dataset[\"validation\"].column_names\n", + "text_column_name = data_args.text_column_name if data_args.text_column_name in column_names else column_names[0]\n", + "\n", + "if data_args.block_size is None:\n", + " block_size = tokenizer.model_max_length\n", + " if block_size > config.max_position_embeddings:\n", + " logger.warning(\n", + " f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n", + " \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n", + " )\n", + " block_size = 1024\n", + "else:\n", + " if data_args.block_size > tokenizer.model_max_length:\n", + " logger.warning(\n", + " f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model \"\n", + " f\"({tokenizer.model_max_length}). 
Using block_size={tokenizer.model_max_length}.\"\n", + " )\n", + " block_size = min(data_args.block_size, tokenizer.model_max_length)\n", + "\n", + "\n", + "def tokenize_function(examples):\n", + " toks = tokenizer(examples[\"question\"],\n", + " examples[\"answer\"],\n", + " max_length=block_size,\n", + " padding=\"max_length\",\n", + " truncation=True,\n", + " return_token_type_ids=True,\n", + " # return_tensors=\"np\",\n", + " )\n", + " labels = toks[\"input_ids\"].copy()\n", + " toks[\"labels\"] = labels\n", + " return toks\n", + "\n", + "lm_datasets = dataset.map(\n", + " tokenize_function,\n", + " batched=True,\n", + " num_proc=data_args.preprocessing_num_workers,\n", + " remove_columns=column_names,\n", + " load_from_cache_file=not data_args.overwrite_cache,\n", + ")\n", + "\n", + "if training_args.do_train:\n", + " if \"train\" not in lm_datasets:\n", + " raise ValueError(\"--do_train requires a train dataset\")\n", + " train_dataset = lm_datasets[\"train\"]\n", + " if data_args.max_train_samples is not None:\n", + " train_dataset = train_dataset.select(range(data_args.max_train_samples))\n", + "\n", + "if training_args.do_eval:\n", + " if \"validation\" not in lm_datasets:\n", + " raise ValueError(\"--do_eval requires a validation dataset\")\n", + " eval_dataset = lm_datasets[\"validation\"]\n", + " if data_args.max_eval_samples is not None:\n", + " eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))" + ], + "id": "Mne7FveveiDF", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PRc5bMYie0lS" + }, + "source": [ + "# Configure TensorBoard and wandb" + ], + "id": "PRc5bMYie0lS" + }, + { + "cell_type": "code", + "metadata": { + "id": "ovvPWRaAexCJ" + }, + "source": [ + "# Enable TensorBoard only on the master node\n", + "has_tensorboard = is_tensorboard_available()\n", + "if has_tensorboard and jax.process_index() == 0:\n", + " try:\n", + " from flax.metrics.tensorboard import SummaryWriter\n", + "\n", + " summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n", + " except ImportError as ie:\n", + " has_tensorboard = False\n", + " logger.warning(\n", + " f\"Unable to display metrics through TensorBoard because some packages are not installed: {ie}\"\n", + " )\n", + "else:\n", + " logger.warning(\n", + " \"Unable to display metrics through TensorBoard because the package is not installed. \"\n", + " \"Please run pip install tensorboard to enable it.\"\n", + " )\n", + "\n", + "# Enable wandb tracking\n", + "has_wandb = find_spec(\"wandb\") is not None\n", + "if jax.process_index() == 0 and has_wandb and (\"wandb\" in training_args.report_to):\n", + " try:\n", + " import wandb\n", + " wandb.init(\n", + " name=training_args.run_name,\n", + " entity=\"wandb\",\n", + " project=\"hf-flax-gpt-neo-copilot\",\n", + " sync_tensorboard=True\n", + " )\n", + " wandb.config.update(training_args)\n", + " wandb.config.update(model_args)\n", + " wandb.config.update(data_args)\n", + " except ImportError as e:\n", + " print(e)\n", + " has_wandb = False\n" + ], + "id": "ovvPWRaAexCJ", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WygKdUFLgFkj" + }, + "source": [ + "# Initialize Training Params" + ], + "id": "WygKdUFLgFkj" + }, + { + "cell_type": "code", + "metadata": { + "id": "ugy2pcG7fA6g" + }, + "source": [ + "# Initialize our training\n", + "rng = jax.random.PRNGKey(training_args.seed)\n", + "rng, dropout_rng = jax.random.split(rng)\n", + "\n", +
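"# `rng` is re-split at the start of each epoch to shuffle the data, while\n", + "# `dropout_rng` is carried in the TrainState and seeds dropout in train_step.\n", +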
"# Store some constant\n", + "num_epochs = int(training_args.num_train_epochs)\n", + "train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps\n", + "eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()\n", + "steps_per_epoch = len(train_dataset) // train_batch_size\n", + "total_train_steps = steps_per_epoch * num_epochs\n", + "\n", + "# Create learning rate schedule\n", + "linear_decay_lr_schedule_fn = create_learning_rate_fn(\n", + " len(train_dataset),\n", + " train_batch_size,\n", + " training_args.num_train_epochs,\n", + " training_args.warmup_steps,\n", + " training_args.learning_rate,\n", + ")\n", + "\n", + "# We use Optax's \"masking\" functionality to not apply weight decay\n", + "# to bias and LayerNorm scale parameters. decay_mask_fn returns a\n", + "# mask boolean with the same structure as the parameters.\n", + "# The mask is True for parameters that should be decayed.\n", + "# Note that this mask is specifically adapted for FlaxGPT2.\n", + "# For other models, one should correct the layer norm parameter naming\n", + "# accordingly.\n", + "def decay_mask_fn(params):\n", + " flat_params = traverse_util.flatten_dict(params)\n", + " flat_mask = {\n", + " path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n", + " for path in flat_params\n", + " }\n", + " return traverse_util.unflatten_dict(flat_mask)\n", + "\n", + "# create optimizer\n", + "if training_args.adafactor:\n", + " # We use the default parameters here to initialize adafactor,\n", + " # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74\n", + " optimizer = optax.adafactor(\n", + " learning_rate=linear_decay_lr_schedule_fn,\n", + " )\n", + "else:\n", + " optimizer = optax.adamw(\n", + " learning_rate=linear_decay_lr_schedule_fn,\n", + " b1=training_args.adam_beta1,\n", + " b2=training_args.adam_beta2,\n", + " eps=training_args.adam_epsilon,\n", + " weight_decay=training_args.weight_decay,\n", + " mask=decay_mask_fn,\n", + " )\n", + " optimizer = optax.chain(\n", + " optax.clip_by_global_norm(1.),\n", + " optimizer\n", + " )\n", + "if training_args.gradient_accumulation_steps > 1:\n", + " optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)\n", + "grad_accum_steps = training_args.gradient_accumulation_steps\n" + ], + "id": "ugy2pcG7fA6g", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2m7ApkWBfhVd" + }, + "source": [ + "# Setup Training and Eval States" + ], + "id": "2m7ApkWBfhVd" + }, + { + "cell_type": "code", + "metadata": { + "id": "2aED4zOefktb" + }, + "source": [ + "# Setup train state\n", + "state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)\n", + "\n", + "if training_args.resume_from_checkpoint:\n", + " state = restore_model_checkpoint(training_args.resume_from_checkpoint, state)\n", + " resume_step = mb_item(state.step)\n", + " if training_args.adafactor:\n", + " state = fake_update(state)\n", + "else:\n", + " resume_step = 0\n", + "\n", + "def loss_fn(logits, labels, labels_mask):\n", + " shift_logits = logits[..., :-1, :]\n", + " shift_labels = labels[..., 1:]\n", + " loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1])) * labels_mask[..., 1:]\n", + " 
return loss.mean()\n", + "\n", + "# Define gradient update step fn\n", + "def train_step(state, batch):\n", + " dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)\n", + "\n", + " def compute_loss(params):\n", + " labels = batch.pop(\"labels\")\n", + " token_type_ids = batch.pop(\"token_type_ids\")\n", + " labels_mask = batch[\"attention_mask\"] * token_type_ids\n", + " del token_type_ids\n", + " logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n", + " loss = loss_fn(logits, labels, labels_mask)\n", + " return loss\n", + "\n", + " grad_fn = jax.value_and_grad(compute_loss)\n", + " loss, grad = grad_fn(state.params)\n", + " grad = jax.lax.pmean(grad, \"batch\")\n", + "\n", + " new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)\n", + "\n", + " metrics = {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}\n", + " metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n", + "\n", + " return new_state, metrics\n", + "\n", + "# Define eval fn\n", + "def eval_step(params, batch):\n", + " labels = batch.pop(\"labels\")\n", + " token_type_ids = batch.pop(\"token_type_ids\")\n", + " labels_mask = batch[\"attention_mask\"] * token_type_ids\n", + " del token_type_ids\n", + " logits = model(**batch, params=params, train=False)[0]\n", + " loss = loss_fn(logits, labels, labels_mask)\n", + "\n", + " # summarize metrics\n", + " metrics = {\"loss\": loss}\n", + " metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n", + " return metrics\n", + "\n", + "# Create parallel version of the train and eval step\n", + "p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n", + "p_eval_step = jax.pmap(eval_step, \"batch\")\n", + "\n", + "# Replicate the train state on each device\n", + "state = state.replicate()" + ], + "id": "2aED4zOefktb", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NT8jaob8f6ZE" + }, + "source": [ + "# Run Training!" + ], + "id": "NT8jaob8f6ZE" + }, + { + "cell_type": "code", + "metadata": { + "id": "iPiYxq78f1Ez" + }, + "source": [ + "logger.info(\"***** Running training *****\")\n", + "logger.info(f\" Num examples = {len(train_dataset)}\")\n", + "logger.info(f\" Num Epochs = {num_epochs}\")\n", + "logger.info(f\" Instantaneous batch size per device = {training_args.per_device_train_batch_size}\")\n", + "logger.info(f\" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}\")\n", + "logger.info(f\" Total optimization steps = {total_train_steps}\")\n", + "\n", + "if not training_args.skip_memory_metrics:\n", + " server = jax.profiler.start_server(9999)\n", + "\n", + "train_time = 0\n", + "train_metrics = []\n", + "resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)\n", + "epochs = tqdm(range(num_epochs), desc=f\"Epoch ... 
({resume_epoch+1}/{num_epochs})\", position=0)\n", + "if resume_step != 0:\n", + " logger.info(f\"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}\")\n", + "for epoch in epochs:\n", + " # ======================== Training ================================\n", + " if epoch < resume_epoch:\n", + " continue\n", + " \n", + " train_start = time.time()\n", + "\n", + " # Create sampling rng\n", + " rng, input_rng = jax.random.split(rng)\n", + "\n", + " # Generate an epoch by shuffling sampling indices from the train dataset\n", + " train_loader = data_loader(input_rng, train_dataset, train_batch_size // grad_accum_steps, shuffle=True)\n", + " # train\n", + " steps_trained_progress_bar = tqdm(range(steps_per_epoch), desc=\"Training...\", position=1,\n", + " leave=False, initial=(resume_step // grad_accum_steps))\n", + " for step in range(steps_per_epoch * grad_accum_steps):\n", + " cur_step = epoch * (steps_per_epoch*grad_accum_steps) + step\n", + " # skip to the step from which we are resuming\n", + " if cur_step < resume_step:\n", + " continue\n", + "\n", + " batch = next(train_loader)\n", + " state, train_metric = p_train_step(state, batch)\n", + " train_metrics.append(train_metric)\n", + " if step % grad_accum_steps == 0:\n", + " steps_trained_progress_bar.update(1)\n", + "\n", + " if cur_step % (training_args.logging_steps * grad_accum_steps)== 0 and cur_step > 0:\n", + " # Save metrics\n", + " train_metric = unreplicate(train_metric)\n", + " train_time += time.time() - train_start\n", + " if has_tensorboard and jax.process_index() == 0:\n", + " write_train_metric(summary_writer, train_metrics, train_time, cur_step)\n", + " if has_wandb and jax.process_index() == 0 and (\"wandb\" in training_args.report_to):\n", + " # TODO: add accumulation of metrics\n", + " _metrics = {k if k==\"learning_rate\" else f\"train_{k}\":mb_item(v.mean()) for k, v in train_metric.items()}\n", + " wandb.log({\"training_step\":cur_step, **_metrics}, commit=True)\n", + "\n", + " epochs.write(\n", + " f\"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})\"\n", + " )\n", + "\n", + " train_metrics = []\n", + "\n", + " if cur_step % (training_args.eval_steps * grad_accum_steps) == 0 and cur_step > 0:\n", + " # ======================== Evaluating ==============================\n", + " eval_metrics = []\n", + " eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)\n", + " eval_steps = len(eval_dataset) // eval_batch_size\n", + " for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n", + " # Model forward\n", + " batch = next(eval_loader)\n", + " metrics = p_eval_step(state.params, batch)\n", + " eval_metrics.append(metrics)\n", + "\n", + " # normalize eval metrics\n", + " eval_metrics = get_metrics(eval_metrics)\n", + " eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n", + "\n", + " try:\n", + " eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n", + " except OverflowError:\n", + " eval_metrics[\"perplexity\"] = float(\"inf\")\n", + "\n", + " # Print metrics and update progress bar\n", + " desc = f\"Step... 
({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})\"\n", + " epochs.write(desc)\n", + " epochs.desc = desc\n", + "\n", + " # Save metrics\n", + " if has_tensorboard and jax.process_index() == 0:\n", + " # cur_step = epoch * (len(train_dataset) // train_batch_size)\n", + " write_eval_metric(summary_writer, eval_metrics, cur_step)\n", + " if has_wandb and jax.process_index() == 0 and (\"wandb\" in training_args.report_to):\n", + " _metrics = {f\"eval_{k}\": mb_item(v) for k, v in eval_metrics.items()}\n", + " wandb.log({\"eval_step\": cur_step, **_metrics})\n", + "\n", + " if (cur_step % (training_args.save_steps * grad_accum_steps) == 0 and\n", + " training_args.save_strategy == IntervalStrategy.STEPS and\n", + " cur_step > 0):\n", + " # save a checkpoint every save_steps and push it to the hub\n", + " if jax.process_index() == 0:\n", + " save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer,\n", + " push_to_hub=training_args.push_to_hub)\n", + " if training_args.save_total_limit is not None:\n", + " rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)\n", + "\n", + " if training_args.save_strategy == IntervalStrategy.EPOCH:\n", + " # save a checkpoint after each epoch and push it to the hub\n", + " if jax.process_index() == 0:\n", + " save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer,\n", + " push_to_hub=training_args.push_to_hub)\n", + " if training_args.save_total_limit is not None:\n", + " rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)\n", + "\n", + "\n", + "# Save the model after training is over\n", + "if jax.process_index() == 0:\n", + " save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer, push_to_hub=training_args.push_to_hub)" + ], + "id": "iPiYxq78f1Ez", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dKPEkNuXD4s-" + }, + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/gdrive')" + ], + "id": "dKPEkNuXD4s-", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "p2E4EKhCWEC5" + }, + "source": [ + "from google.colab import files\n", + "\n", + "files.download('/content/gpt-code-clippy-apps-125m')" + ], + "id": "p2E4EKhCWEC5", + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file