diff --git a/apps_eval_util.py b/apps_eval_util.py new file mode 100644 index 0000000..d2179fc --- /dev/null +++ b/apps_eval_util.py @@ -0,0 +1,545 @@ + +import argparse +import json +import os +import sys +import io +import faulthandler + +# used for debugging to time steps +from datetime import datetime + +# to run the solution files we're using a timing based approach +import signal + +import numpy as np +# for capturing the stdout +from io import StringIO +from typing import get_type_hints +from typing import List, Tuple +# used for testing the code that reads from input +from unittest.mock import patch, mock_open + +from pyext import RuntimeModule + +from enum import Enum +class CODE_TYPE(Enum): + call_based = 0 + standard_input = 1 + +# stuff for setting up signal timer +class TimeoutException(Exception): + pass +def timeout_handler(signum, frame): + print("alarm went off") + #return + raise TimeoutException +signal.signal(signal.SIGALRM, timeout_handler) +timeout = 4 # seconds + +# used to capture stdout as a list +# from https://stackoverflow.com/a/16571630/6416660 +# alternative use redirect_stdout() from contextlib +class Capturing(list): + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = StringIO() + # Make closing the StringIO a no-op + self._stringio.close = lambda x: 1 + return self + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + del self._stringio # free up some memory + sys.stdout = self._stdout + + +def parse_args(): + parser = argparse.ArgumentParser(description="Utility for testing code generation.") + parser.add_argument("-v", "--verbosity-level", action="store", type=int, + help="") + parser.add_argument("-s", "--source", type=str, default="leetcode", + choices=["leetcode", "atcoder", "codewars",], + help="which data source to gather from.") + parser.add_argument("-d", "--data", type=str, default="question", + choices=["question", "q", "solutions", "sol", "s", "starter", "tests", "t"], + help="which type of data to receive.") + parser.add_argument("-n", "--number", type=int, default=0, + help="which problem to query.") + + args = parser.parse_args() + return args + + +def get_valid_problems(data_dir="leetcode"): + # these are unnecessary atm + if data_dir == "leetcode": + root = os.path.join(args.source, "data") + elif data_dir == "atcoder": + pass + + root = os.path.join(data_dir, "data") + if os.path.exists(os.path.join(data_dir, "valid_problems.json")): + with open(os.path.join(data_dir, "valid_problems.json"), "r") as f: + return json.load(f) + + # after we compute it once let's save it and load that instead + # TODO determine if might be better to reload each time + tmp = os.listdir(root) + valid_probs = [] + for folder in tmp: + prob_path = os.path.join(root, folder) + files = os.listdir(prob_path) + #TODO add more validity checks + if "input_output.json" in files or "sols.json" in files: + valid_probs.append(prob_path) + valid_probs = sorted(valid_probs) + #with open(os.path.join(args.source,"valid_problems.json"), "w") as f: + # json.dump(valid_probs, f) + return valid_probs + + +def get_question(problem_list, prob_index): + root = problem_list[prob_index] + #print("get q", root) + if os.path.exists(os.path.join(root, "question.txt")): + with open(os.path.join(root, "question.txt")) as f: + question = f.readlines() + else: + print("question prompt not found") + question = "" + question = "".join(question) + return question + + +def get_solutions(problem_list, prob_index): + root = problem_list[prob_index] + if os.path.exists(os.path.join(root, "solutions.json")): + with open(os.path.join(root, "solutions.json")) as f: + sols = json.load(f) + return sols + + +def run_test(prob_path:str=None, problem_list:List[str]=None, prob_index:int=None, + test:str=None, debug:bool=False): + """ + if test is not None it'll try to run the code. + otherwise it'll just return an input and output pair. + """ + if prob_path is None and problem_list is None: + print("please provide either prob_path or problem_list") + exit() + + if debug: + print(f"start = {datetime.now().time()}") + if prob_path is not None: + root = prob_path + elif problem_list is not None: + root = problem_list[prob_index] + + if os.path.exists(os.path.join(root, "input_output.json")): + with open(os.path.join(root, "input_output.json")) as f: + in_outs = json.load(f) + if debug: + print(f"test cases json = {in_outs['inputs']} {in_outs['outputs']}") + + if in_outs.get("fn_name") is None: + which_type = CODE_TYPE.standard_input # Standard input + method_name = None + else: + which_type = CODE_TYPE.call_based # Call-based + method_name = in_outs["fn_name"] + if debug: + print(f"loaded json = {datetime.now().time()}") + + #else: + # continue + if test is None: + return in_outs + elif test is not None: + results = [] + sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" + if debug: + print(f"loading test code = {datetime.now().time()}") + + if which_type == CODE_TYPE.call_based: + sol += test + if debug: # or True: + print(f"sol = {sol}") + signal.alarm(timeout) + try: + tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) + if "class Solution" not in test: + tmp = tmp_sol + else: + tmp = tmp_sol.Solution() + signal.alarm(0) + except Exception as e: + signal.alarm(0) + print(f"type 0 compilation error = {e}") + results.append(-2) + return results + signal.alarm(0) + + elif which_type == CODE_TYPE.standard_input: + # sol + tmp_test = test.split("\n") + + new_test = [] + for x in tmp_test: + if (not x.startswith("from ")) and (not x.startswith("import ")): + new_test.append("\t" + x + "\n") + else: + new_test.append(x + "\n") + tmp_test = new_test + + new_test = "" + started = False + for i in tmp_test: + if i.startswith("\t") and not started: + new_test += "stdin = sys.stdin\nstdout = sys.stdout\n" + new_test += "def code():\n" + new_test += i + started = True + elif started and ((i.startswith("from ")) or (i.startswith("import "))): + new_test += "\t" + i + else: + new_test += i + tmp_test = new_test + + sol += tmp_test + if debug: + print(f"sol = {sol}") + # print(f"{o}") + method_name = "code" + signal.alarm(timeout) + try: + tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) + tmp = tmp_sol + signal.alarm(0) + except Exception as e: + signal.alarm(0) + print(f"type 1 compilation error = {e}") + results.append(-2) + return results + signal.alarm(0) + if debug: + print(f"get method = {datetime.now().time()}") + + try: + method = getattr(tmp, method_name) # get_attr second arg must be str + except: + signal.alarm(0) + e = sys.exc_info() + print(f"unable to get function error = {e}") + return results + + for index, inputs in enumerate(in_outs["inputs"]): + # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) + try: + if isinstance(inputs[0], dict): + inputs = [{int(k): v for k,v in inputs[0].items()}] + except: + True + try: + if isinstance(in_outs["outputs"][index], dict): + in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}] + except: + True + try: + if isinstance(in_outs["outputs"][index][0], dict): + in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}] + except: + True + + if debug: + print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}") + if which_type == CODE_TYPE.call_based: # Call-based + signal.alarm(timeout) + faulthandler.enable() + try: + # print("------------") + # print(inputs) + output = method(*inputs) + + # ground truth sequences are not tuples + if isinstance(output, tuple): + output = list(output) + + tmp_result = output == in_outs["outputs"][index] + if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]: + tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) + + # ground truth sequences are not tuples + try: + if isinstance(output[0], tuple): + tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0]) + except: + True + results.append(tmp_result) + + # reset the alarm + signal.alarm(0) + except Exception as e: + signal.alarm(0) + faulthandler.disable() + print(f"Standard input runtime error or time limit exceeded error = {e}") + results.append(-1) + continue + faulthandler.disable() + signal.alarm(0) + if debug: + print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + elif which_type == CODE_TYPE.standard_input: # Standard input + faulthandler.enable() + signal.alarm(timeout) + passed = False + + if isinstance(inputs, list): + inputs = "\n".join(inputs) + if isinstance(in_outs['outputs'][index], list): + in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index]) + + with Capturing() as output: + try: + call_method(method, inputs) + # reset the alarm + signal.alarm(0) + passed = True + except Exception as e: + # runtime error or took too long + signal.alarm(0) + print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") + results.append(-1) + signal.alarm(0) + + if not passed: + if debug: + nl = "\n" + if not isinstance(inputs, list): + print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + else: + print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + continue + + if passed and debug: + print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}") + + if custom_compare_(output, in_outs['outputs'][index]): + tmp_result = True + results.append(tmp_result) + continue + + # ground truth sequences are expressed as lists not tuples + if isinstance(output, tuple): + output = list(output) + + tmp_result = False + try: + tmp_result = (output == [in_outs["outputs"][index]]) + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + if isinstance(output[0], str): + tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index]) + except Exception as e: + print(f"Failed check1 exception = {e}") + pass + + if tmp_result == True: + results.append(tmp_result) + continue + + # try one more time without \n + if isinstance(in_outs["outputs"][index], list): + for tmp_index, i in enumerate(in_outs["outputs"][index]): + in_outs["outputs"][index][tmp_index] = i.split("\n") + in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x] + else: + in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") + in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) + in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index])) + + try: + tmp_result = (output == [in_outs["outputs"][index]]) + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + except Exception as e: + print(f"Failed check2 exception = {e}") + pass + + if tmp_result == True: + results.append(tmp_result) + continue + + # try by converting the output into a split up list too + if isinstance(output, list): + output = list(filter(len, output)) + + if debug: + nl = "\n" + if not isinstance(inputs, list): + print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + else: + print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + + if tmp_result == True: + results.append(tmp_result) + continue + + try: + tmp_result = (output == [in_outs["outputs"][index]]) + if isinstance(in_outs["outputs"][index], list): + tmp_result = tmp_result or (output == in_outs["outputs"][index]) + except Exception as e: + print(f"Failed check3 exception = {e}") + pass + + try: + output_float = [float(e) for e in output] + gt_float = [float(e) for e in in_outs['outputs'][index]] + tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) + except Exception as e: + pass + try: + if isinstance(output[0], list): + output_float = [float(e) for e in output[0]] + gt_float = [float(e) for e in in_outs['outputs'][index][0]] + tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)) + except Exception as e: + pass + + if tmp_result == True: + results.append(tmp_result) + continue + + # try by converting the stuff into split up list + if isinstance(in_outs["outputs"][index], list): + for tmp_index, i in enumerate(in_outs["outputs"][index]): + in_outs["outputs"][index][tmp_index] = set(i.split()) + else: + in_outs["outputs"][index] = set(in_outs["outputs"][index].split()) + + try: + tmp_result = (output == in_outs["outputs"][index]) + except Exception as e: + print(f"Failed check4 exception = {e}") + continue + + if tmp_result == True: + results.append(tmp_result) + continue + + # try by converting the output into a split up list too + if isinstance(output, list): + for tmp_index, i in enumerate(output): + output[tmp_index] = i.split() + output = list(filter(len, output)) + for tmp_index, i in enumerate(output): + output[tmp_index] = set(i) + else: + output = output.split() + output = list(filter(len, output)) + output = set(output) + + try: + tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index])) + except Exception as e: + print(f"Failed check5 exception = {e}") + + + # if they are all numbers, round so that similar numbers are treated as identical + try: + tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\ + set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index])) + except Exception as e: + print(f"Failed check6 exception = {e}") + + if tmp_result == True and debug: + print("PASSED") + + results.append(tmp_result) + + if debug: + nl = "\n" + if not isinstance(inputs, list): + print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + else: + print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}") + + + return results + +def custom_compare_(output, ground_truth): + + if isinstance(output, list): + output_1 = "\n".join(output) + if stripped_string_compare(output_1, ground_truth): + return True + + if isinstance(output, list): + output_2 = [o.lstrip().rstrip() for o in output] + output_2 = "\n".join(output_2) + if stripped_string_compare(output_2, ground_truth): + return True + + return False + +def stripped_string_compare(s1, s2): + s1 = s1.lstrip().rstrip() + s2 = s2.lstrip().rstrip() + return s1 == s2 + +def call_method(method, inputs): + + if isinstance(inputs, list): + inputs = "\n".join(inputs) + + inputs_line_iterator = iter(inputs.split("\n")) + + # sys.setrecursionlimit(10000) + + # @patch('builtins.input', side_effect=inputs.split("\n")) + @patch('builtins.open', mock_open(read_data=inputs)) + @patch('sys.stdin', StringIO(inputs)) + @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator)) + @patch('sys.stdin.readlines', lambda *args: inputs.split("\n")) + @patch('sys.stdin.read', lambda *args: inputs) + # @patch('sys.stdout.write', print) + def _inner_call_method(_method): + try: + return _method() + except SystemExit as e: + pass + finally: + pass + return _inner_call_method(method) + +def main(args): + print(args) + problem_list = sorted(get_valid_problems(args.source)) + print(f"number of problems = {len(problem_list)}") + prob_index = args.number + print(f"problem is {problem_list[prob_index]}") + + # This checks it correctly loaded. remove this later + assert prob_index < len(problem_list) + + if args.data == "q" or args.data == "question": + tmp = get_question(problem_list, prob_index) + print("q", tmp) + elif args.data in ["solutions", "sol", "s",]: + tmp = get_solutions(problem_list, prob_index) + print("sol", tmp) + elif args.data == "starter": + tmp = get_starter(problem_list, prob_index) + print("starter", tmp) + elif args.data in ["test", "t"]: + # test it with sols + sols = get_solutions(problem_list, prob_index) + tmp = run_test(problem_list, prob_index, test=sols[0]) + + print("results = ", tmp) + print("-2 = compile error, -1 is runtime error, False failed test, True passed test") + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/eval_apps.py b/eval_apps.py new file mode 100644 index 0000000..3aba257 --- /dev/null +++ b/eval_apps.py @@ -0,0 +1,176 @@ +""" +Run solutions from one problem. +""" + +import io +import json +import logging +import math +import numpy as np +import os +import pprint +import sys +import testing_util as test_util +import time + +# for timing debugging +from datetime import datetime, date +from tqdm import tqdm + +from typing import List + +def print_results(results, args): + res = [] + per_prob_res = [] + all_correct = [] + for index in results: + res.extend(results[index]) + per_prob_res.append(np.mean(results[index])) + all_correct.append(np.all(results[index])) + tmp_results = res + compile_errors = len(tmp_results[tmp_results==-2]) + runtime_errors = len(tmp_results[tmp_results==-1]) + failures = len(tmp_results[tmp_results==False]) + successes = len(tmp_results[tmp_results==True]) + total_testcases = len(res) + if args.debug: + print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases }") + print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}") + print(f"number of test cases run = {total_testcases}") + + print(f"Test Case Average (average accuracy over problems) = {np.mean(per_prob_res)}") + print(f"Strict Accuracy (all test cases passed / total problems) = {np.mean(all_correct)}") + + +def eval_and_save_problems(args): + with open(args.test_loc, "r") as f: + problems = sorted(json.load(f)) + + print(len(problems)) + gpt_codes = {} + gpt_bleu = {} + gpt_codebleu = {} + results = {} + codes_loc = os.path.join(args.save, f"all_codes.json") + if not os.path.exists(codes_loc): + codes_loc = os.path.join(args.save, f"{args.start}-{args.end}_codes.json") + + if os.path.exists(codes_loc): + results_loc = os.path.join(args.save, f"all_results.json") + else: + results_loc = os.path.join(args.save, f"{args.start}-{args.end}_results.json") + print(codes_loc, results_loc) + + with open(codes_loc, "r") as f: + gpt_codes = json.load(f) + + if args.index: + problems = [problems[args.index]] + else: + if args.start > len(problems) or args.start < 0: + print(f"start index {args.start} > number of problems {len(problems)}") + return + start = args.start + if args.end is None or args.end > len(problems): + end = len(problems) + else: + end = args.end + problems = problems[start:end] + + if args.stop_early: + problems = problems[:args.stop_early] + + # main eval loop + for index, problem in enumerate(tqdm(problems)): + try: + if args.debug: + print(f"\n\nproblem path = {problem}") + output_str = gpt_codes[str(index+args.start)] + except: + print("CANNOT FIND OUTPUT_STR FOR", problem) + continue + prob_path = os.path.join(args.root, problem) + + with open(os.path.join(prob_path, "solutions.json"), "r") as f: + sols = json.load(f) + + if not os.path.exists(args.save): + os.makedirs(args.save) + + res = [] + for o_idx, o in enumerate(output_str): + if args.debug: + print(f"\nTesting solution {o_idx}") + curr_res = [-2] + try: + curr_res = test_util.run_test(prob_path=prob_path, test=o, debug=args.debug) + fixed = [] + for e in curr_res: + if isinstance(e, np.ndarray): + e = e.item(0) + if isinstance(e, np.bool_): + e = bool(e) + fixed.append(e) + curr_res = fixed + if not np.all(curr_res): + print(f"Results were not all True: {curr_res}") + except Exception as e: + print(f"test framework exception = {repr(e)}{e}\n") + break + finally: + assert isinstance(curr_res, list) + res.append(curr_res) + + if args.debug: + print(f"\nHow to read results [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case") + #print(f"results = {res}") + + results[index+args.start+args.index] = res + + with open(results_loc, "w") as f: + try: + f.write(json.dumps(results)) + except Exception as e: + import pdb; pdb.set_trace() + print("didn't save problem due to {e}") + + return results + + +def main(args): + + argsdict = vars(args) + print(pprint.pformat(argsdict)) + + if args.print_results: + results = {} + codes_loc = os.path.join(args.save, f"all_codes.json") + if os.path.exists(codes_loc): + results_loc = os.path.join(args.save, f"all_results.json") + else: + results_loc = os.path.join(args.save, f"{args.start}-{args.end}_results.json") + with open(results_loc, "r") as f: + results = json.load(f) + else: + results = eval_and_save_problems(args) + + print_results(results, args) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Testing a Language Model on Python Code") + parser.add_argument("-t","--test_loc", default="../data_split/test.json", type=str, help="path to the json containing problem paths to be evaluated.") + parser.add_argument("-r","--root", default="../", type=str, help="where the data is stored.") + parser.add_argument("-s","--start", default=0, type=int) + parser.add_argument("-e","--end", default=None, type=int, help="If you want to evaluate a subset of problems specify start and ending index. File with start and ending prefix must exist typically used with batch evaluation.") + parser.add_argument("-i", "--index", default=0, type=int) + parser.add_argument("-p", "--print_results", action="store_true", help="If you have already evaluated the results and only want to print them.") + parser.add_argument("-d", "--debug", action="store_true") + parser.add_argument("--save", type=str, default="./results", help="Where the evaluated data is loaded from and results saved to.") + parser.add_argument("--stop-early", default=None, type=int) + + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/finetune_apps.sh b/finetune_apps.sh new file mode 100644 index 0000000..bd0d454 --- /dev/null +++ b/finetune_apps.sh @@ -0,0 +1,31 @@ +#! /bin/bash +./run_clm_apps.py \ + --output_dir $HOME/gpt-code-clippy-apps-2 \ + --model_name_or_path EleutherAI/gpt-neo-1.3B \ + --dataset_name ./apps.py \ + --do_train --do_eval \ + --block_size="1024" \ + --per_device_train_batch_size="2" \ + --per_device_eval_batch_size="2" \ + --preprocessing_num_workers="16" \ + --learning_rate="2e-5" \ + --warmup_steps="5000" \ + --adam_beta1="0.9" \ + --adam_beta2="0.98" \ + --weight_decay="0.1" \ + --overwrite_output_dir \ + --num_train_epochs="5" \ + --logging_steps="20" \ + --eval_steps="1000" \ + --push_to_hub="False" \ + --report_to="wandb" \ + --dtype="bfloat16" \ + --skip_memory_metrics="False" \ + --save_steps="1000" \ + --save_strategy epoch \ + --save_total_limit 2 \ + --gradient_accumulation_steps 2 \ + --adafactor \ + # --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \ + # --max_train_samples="10000" \ + # --max_eval_samples="1000" diff --git a/partitions.py b/partitions.py deleted file mode 100644 index 6088532..0000000 --- a/partitions.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2021 The Google Research Authors and The HuggingFace Team All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for constructing PyTrees of PartitionSpecs.""" - -# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py - -import re - -from flax.core.frozen_dict import freeze -from flax.traverse_util import flatten_dict, unflatten_dict -from jax.experimental import PartitionSpec as P - - -# Sentinels -_unmatched = object() - -# For specifying empty leaf dict `{}` -empty_dict = object() - - -def _match(qs, ks): - """Return True if regexes in qs match any window of strings in tuple ks.""" - # compile regexes and force complete match - qts = tuple(map(lambda x: re.compile(x + "$"), qs)) - for i in range(len(ks) - len(qs) + 1): - matches = [x.match(y) for x, y in zip(qts, ks[i:])] - if matches and all(matches): - return True - return False - - -def _replacement_rules(rules): - def replace(key, val): - for rule, replacement in rules: - if _match(rule, key): - return replacement - return val - - return replace - - -# PartitionSpec for GPTNeo -# replicate the hidden dim and shard feed-forward and head dim -def _get_partition_rules(): - return [ - # embeddings - (("transformer", "wpe", "embedding"), P("mp", None)), - (("transformer", "wte", "embedding"), P("mp", None)), - # atention - (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")), - (("attention", "out_proj", "kernel"), P("mp", None)), - (("attention", "out_proj", "bias"), None), - # mlp - (("mlp", "c_fc", "kernel"), P(None, "mp")), - (("mlp", "c_fc", "bias"), P("mp")), - (("mlp", "c_proj", "kernel"), P("mp", None)), - (("mlp", "c_proj", "bias"), None), - # layer norms - ((r"ln_\d+", "bias"), None), - ((r"\d+", r"ln_\d+", "scale"), None), - (("ln_f", "bias"), None), - (("ln_f", "scale"), None), - ] - - -def set_partitions(in_dict): - rules = _get_partition_rules() - replace = _replacement_rules(rules) - initd = {k: _unmatched for k in flatten_dict(in_dict)} - result = {k: replace(k, v) for k, v in initd.items()} - assert _unmatched not in result.values(), "Incomplete partition spec." - return freeze(unflatten_dict(result)) -{"mode":"full","isActive":false} \ No newline at end of file diff --git a/run_clm_mp_apps.py b/run_clm_mp_apps.py new file mode 100755 index 0000000..c201010 --- /dev/null +++ b/run_clm_mp_apps.py @@ -0,0 +1,636 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pre-training/Fine-tuning the GPTNeo model for causal language modeling on a text file or a dataset using model parallelism. +""" + +import logging +import math +import os +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Optional + +import datasets +import numpy as np +from datasets import Dataset, load_dataset +from tqdm import tqdm + +import jax +import jax.numpy as jnp +import optax +import transformers +from flax.core.frozen_dict import freeze, unfreeze +from flax.training.common_utils import onehot, stack_forest +from jax.experimental.maps import mesh +from jax.experimental.pjit import pjit +from partitions import set_partitions +from transformers import ( + CONFIG_MAPPING, + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + AutoConfig, + AutoTokenizer, + FlaxAutoModelForCausalLM, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, +) +from transformers.testing_utils import CaptureLogger + + +logger = logging.getLogger(__name__) + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`." + }, + ) + from_pt: Optional[bool] = field( + default=False, + metadata={"help": "Whether the model weights should be converted from pytorch."}, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + block_size: Optional[int] = field( + default=None, + metadata={ + "help": "Optional input sequence length after tokenization. " + "The training dataset will be truncated in block of this size for training. " + "Default to the model max input length for single sentence inputs (take into account special tokens)." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + + +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False): + """ + Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices. + Shuffle batches if `shuffle` is `True`. + """ + steps_per_epoch = len(dataset) // batch_size + + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + else: + batch_idx = jnp.arange(len(dataset)) + + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: jnp.array(v) for k, v in batch.items()} + yield batch + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = stack_forest(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False + ) + + if "validation" not in dataset.keys(): + dataset["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + dataset["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained config and tokenizer + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if training_args.do_train: + column_names = dataset["train"].column_names + else: + column_names = dataset["validation"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + + if data_args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > config.max_position_embeddings: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 + else: + if data_args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(data_args.block_size, tokenizer.model_max_length) + + + def tokenize_function(examples): + toks = tokenizer(examples["question"], + examples["answer"], + max_length=block_size, + padding="max_length", + truncation=True, + return_token_type_ids=True, + # return_tensors="np", + ) + labels = toks["input_ids"].copy() + toks["labels"] = labels + return toks + + + lm_datasets = dataset.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + + if training_args.do_train: + if "train" not in lm_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = lm_datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in lm_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = lm_datasets["validation"] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) + + # Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constant + num_epochs = int(training_args.num_train_epochs) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + steps_per_epoch = len(train_dataset) // train_batch_size + total_train_steps = steps_per_epoch * num_epochs + + # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models + # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions + # GPT2's vocab is odd, we need to resize it for fine-tuning + model = FlaxAutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), from_pt=model_args.from_pt, + ) + + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, + ) + + if training_args.adafactor: + optimizer = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + ) + else: + optimizer = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + ) + + def get_initial_state(params): + state = optimizer.init(params) + return tuple(state), params + + # Get PartitionSpec for model params + param_spec = set_partitions(unfreeze(model.params)) + + # Get the PyTree for opt_state, we don't actually initialize the opt_state yet. + params_shapes = jax.tree_map(lambda x: x.shape, model.params) + state_shapes = jax.eval_shape(get_initial_state, params_shapes) + + # get PartitionSpec for opt_state, this is very specific to adamw + # TODO: optax returns different state for different optimizers, how can we handle this generically ? + # or maybe we don't since in our examples we just use adamw or adafactor + def get_opt_spec(x): + if isinstance(x, dict): + return param_spec + return None + + opt_state_spec, param_spec = jax.tree_map( + get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) + ) + + # pjit the get_initial_state function to shard params and init + # optimizer state in sharded way + p_get_initial_state = pjit( + get_initial_state, + in_axis_resources=None, + out_axis_resources=(opt_state_spec, param_spec), + ) + + # hack: move the inital params to CPU to free up device memory + # TODO: allow loading weights on CPU in pre-trained model + model.params = jax.tree_map(lambda x: np.asarray(x), model.params) + + # mesh defination + mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count()) + + # actually initialize the opt_state + with mesh(mesh_devices, ("dp", "mp")): + opt_state, params = p_get_initial_state(freeze(model.params)) + + # cross-entropy with z loss + def loss_fn(logits, labels, z_loss=0): + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + + shift_labels = onehot(shift_labels, shift_logits.shape[-1]) + + shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True)) + log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True)) + log_softmax = shift_logits - log_z + loss = -jnp.sum(shift_labels * log_softmax, axis=-1) + + loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss + + return loss.mean() + + # Define gradient update step fn + # TODO: try to use TrainState instead of passing params and opt_state individually + def train_step(params, opt_state, dropout_rng, batch, step): + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + # TODO: mask question in loss_func + token_type_ids = batch.pop('token_type_ids') + logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss = loss_fn(logits, labels, z_loss=1.0) + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grads = grad_fn(params) + + updates, new_opt_state = optimizer.update(grads, opt_state, params) + new_params = optax.apply_updates(params, updates) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)} + return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1 + + # Define eval fn + def eval_step(input_ids, labels, params): + logits = model(input_ids=input_ids, params=params, train=False)[0] + loss = loss_fn(logits, labels) + # metrics + return {"loss": loss} + + p_train_step = pjit( + train_step, + in_axis_resources=(param_spec, opt_state_spec, None, None, None), + out_axis_resources=(param_spec, opt_state_spec, None, None, None), + donate_argnums=(0, 1), + ) + + p_eval_step = pjit( + eval_step, + in_axis_resources=(None, None, param_spec), + out_axis_resources=None, + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") + logger.info(f" Total optimization steps = {total_train_steps}") + + train_time = 0 + train_metrics = [] + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + global_step = 0 + # we are not doing 2D parallelism (yet!), this just does model parallelism + with mesh(mesh_devices, ("dp", "mp")): + for _ in epochs: + # ======================== Training ================================ + train_start = time.time() + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # Generate an epoch by shuffling sampling indices from the train dataset + train_metrics = [] + train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) + steps_per_epoch = len(train_dataset) // train_batch_size + + # train + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + + params, opt_state, dropout_rng, train_metric, global_step = p_train_step( + params, + opt_state, + dropout_rng, + batch, + global_step, + ) + train_metrics.append(train_metric) + + cur_step = global_step + + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + # Save metrics + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) + + train_metrics = [] + + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) + eval_steps = len(eval_dataset) // eval_batch_size + + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + batch = next(eval_loader) + metrics = p_eval_step(batch["input_ids"], batch["labels"], params) + eval_metrics.append(metrics) + + # normalize eval metrics + eval_metrics = stack_forest(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + + try: + eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) + except OverflowError: + eval_metrics["perplexity"] = float("inf") + + logger.info( + f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}" + ) + + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(params) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of step {cur_step}", + ) + + +if __name__ == "__main__": + main() diff --git a/run_clm_streaming.sh b/run_clm_streaming.sh index b0cf84b..0f1e917 100644 --- a/run_clm_streaming.sh +++ b/run_clm_streaming.sh @@ -1,16 +1,16 @@ #! /bin/bash ./run_clm_streaming_flax_v2.py \ - --output_dir $HOME/gpt-neo-125M-test \ - --model_name_or_path="EleutherAI/gpt-neo-125M" \ + --output_dir $HOME/gpt-neo-13B-test \ + --model_name_or_path EleutherAI/gpt-neo-1.3B \ --dataset_name $HOME/gpt-code-clippy/code_clippy.py \ - --data_dir /home/shared/code-clippy-dataset/merged-data \ + --data_dir /home/arto/exdata/merged-data \ --text_column_name="text" \ --do_train --do_eval \ - --block_size="2048" \ - --per_device_train_batch_size="8" \ - --per_device_eval_batch_size="16" \ + --block_size="1024" \ + --per_device_train_batch_size="1" \ + --per_device_eval_batch_size="2" \ --preprocessing_num_workers="8" \ - --learning_rate="6e-4" \ + --learning_rate="1e-4" \ --max_steps 500 \ --warmup_steps 150 \ --decay_steps 250 \ @@ -26,9 +26,9 @@ --skip_memory_metrics="False" \ --save_steps="50" \ --save_total_limit 2 \ - --gradient_accumulation_steps 8 \ + --gradient_accumulation_steps 1 \ --report_to="wandb" \ - --run_name="testing-mini" \ + --run_name="testing" \ --max_eval_samples 100 \ --save_optimizer true \ # --adafactor \