mirror of
https://github.com/CodedotAl/gpt-code-clippy.git
synced 2024-10-26 09:17:45 +03:00
Added file to download licenses
This commit is contained in:
commit
d70d95082d
9
.gitmodules
vendored
9
.gitmodules
vendored
@ -1,3 +1,12 @@
|
||||
[submodule "dependency_repos/github-downloader"]
|
||||
path = dependency_repos/github-downloader
|
||||
url = https://github.com/EleutherAI/github-downloader
|
||||
[submodule "dependency_repos/apps"]
|
||||
path = dependency_repos/apps
|
||||
url = https://github.com/hendrycks/apps.git
|
||||
[submodule "dependency_repos/human-eval"]
|
||||
path = dependency_repos/human-eval
|
||||
url = https://github.com/openai/human-eval
|
||||
[submodule "dependency_repos/CodeXGLUE"]
|
||||
path = dependency_repos/CodeXGLUE
|
||||
url = https://github.com/microsoft/CodeXGLUE
|
||||
|
File diff suppressed because it is too large
Load Diff
208
EDA.ipynb
208
EDA.ipynb
@ -1,208 +0,0 @@
|
||||
{
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3.8.10 64-bit ('.venv': venv)"
|
||||
},
|
||||
"interpreter": {
|
||||
"hash": "25d32b45eddbc1e23c07b06a2c9ff49f418e028170a37fc806346e2c2002bf83"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2,
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"from collections import Counter\n",
|
||||
"from uuid import uuid4\n",
|
||||
"\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers import AutoTokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Using the latest cached version of the module from /home/shpotes/.cache/huggingface/modules/datasets_modules/datasets/code_clippy/86b09b4a623c1c39753a8ad165e05757d9a97daf132ac71d3b6eb791e7da16dd (last modified on Fri Jul 9 22:06:59 2021) since it couldn't be found locally at $HOME/gpt-code-clippy/code_clippy.py/code_clippy.py or remotely (FileNotFoundError).\n",
|
||||
"Using the latest cached version of the module from /home/shpotes/.cache/huggingface/modules/datasets_modules/datasets/code_clippy/86b09b4a623c1c39753a8ad165e05757d9a97daf132ac71d3b6eb791e7da16dd (last modified on Fri Jul 9 22:06:59 2021) since it couldn't be found locally at $HOME/gpt-code-clippy/code_clippy.py/code_clippy.py or remotely (FileNotFoundError).\n",
|
||||
"Using custom data configuration default-668fb26707140662\n",
|
||||
"Using the latest cached version of the module from /home/shpotes/.cache/huggingface/modules/datasets_modules/datasets/code_clippy/86b09b4a623c1c39753a8ad165e05757d9a97daf132ac71d3b6eb791e7da16dd (last modified on Fri Jul 9 22:06:59 2021) since it couldn't be found locally at $HOME/gpt-code-clippy/code_clippy.py/code_clippy.py or remotely (FileNotFoundError).\n",
|
||||
"Using the latest cached version of the module from /home/shpotes/.cache/huggingface/modules/datasets_modules/datasets/code_clippy/86b09b4a623c1c39753a8ad165e05757d9a97daf132ac71d3b6eb791e7da16dd (last modified on Fri Jul 9 22:06:59 2021) since it couldn't be found locally at $HOME/gpt-code-clippy/code_clippy.py/code_clippy.py or remotely (FileNotFoundError).\n",
|
||||
"Using custom data configuration default-668fb26707140662\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train_dataset = load_dataset(\n",
|
||||
" \"$HOME/gpt-code-clippy/code_clippy.py\",\n",
|
||||
" data_dir=\"/home/shared/code-clippy-dataset/merged-data\",\n",
|
||||
" streaming=True,\n",
|
||||
" split=\"train\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"eval_dataset = load_dataset(\n",
|
||||
" \"$HOME/gpt-code-clippy/code_clippy.py\",\n",
|
||||
" data_dir=\"/home/shared/code-clippy-dataset/merged-data\",\n",
|
||||
" streaming=True,\n",
|
||||
" split=\"validation\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125M\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _get_stats(example):\n",
|
||||
" num_of_tokens = map(len, tokenizer(example[\"text\"])['input_ids'])\n",
|
||||
" num_of_lines = map(lambda x: x.count(\"\\n\"), example[\"text\"])\n",
|
||||
" file_name = map(lambda x: \".\".join(x.split(\".\")[:-1]), example[\"file_name\"])\n",
|
||||
" langs = list(map(lambda x: x.split(\".\")[-1], example[\"file_name\"]))\n",
|
||||
"\n",
|
||||
" lang_map = {}\n",
|
||||
" acc_tok = []\n",
|
||||
" acc_lines = []\n",
|
||||
" acc_fnames = []\n",
|
||||
"\n",
|
||||
" for tok, lines, fname, lang in zip(num_of_tokens, num_of_lines, file_name, langs):\n",
|
||||
" if not lang in lang_map:\n",
|
||||
" lang_idx = len(acc_tok)\n",
|
||||
" lang_map[lang] = lang_idx\n",
|
||||
"\n",
|
||||
" acc_tok.append([tok])\n",
|
||||
" acc_lines.append([lines])\n",
|
||||
" acc_fnames.append([Counter({fname: 1})])\n",
|
||||
" else:\n",
|
||||
" lang_idx = lang_map[lang]\n",
|
||||
"\n",
|
||||
" acc_tok[lang_idx][0] += tok\n",
|
||||
" acc_lines[lang_idx][0] += lines\n",
|
||||
" acc_fnames[lang_idx][0].update({fname: 1})\n",
|
||||
" \n",
|
||||
" lang = [[k] for k, v in sorted(lang_map.items(), key=lambda item: item[1])]\n",
|
||||
" _id = [str(uuid4())] * len(lang)\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"ext\": lang,\n",
|
||||
" \"id\": _id,\n",
|
||||
" \"acc_num_of_tokens\": acc_tok,\n",
|
||||
" \"acc_num_of_lines\": acc_lines,\n",
|
||||
" \"acc_file_names\": acc_fnames,\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" def collapse_metrics(per_language_dataset, target_file_name):\n",
|
||||
" num_of_tokens = {}\n",
|
||||
" num_of_lines = {}\n",
|
||||
" file_names = {}\n",
|
||||
"\n",
|
||||
" observed_lang = set()\n",
|
||||
"\n",
|
||||
" for row in tqdm(per_language_dataset):\n",
|
||||
" lang = row['ext'][0]\n",
|
||||
" if lang in observed_lang:\n",
|
||||
" num_of_tokens[lang] += row['acc_num_of_tokens'][0]\n",
|
||||
" num_of_lines[lang] += row['acc_num_of_lines'][0]\n",
|
||||
" file_names[lang].update(row['acc_file_names'][0])\n",
|
||||
" else:\n",
|
||||
" num_of_tokens[lang] = row['acc_num_of_tokens'][0]\n",
|
||||
" num_of_lines[lang] = row['acc_num_of_lines'][0]\n",
|
||||
" file_names[lang] = row['acc_file_names'][0]\n",
|
||||
"\n",
|
||||
" with open(target_file_name, 'wb') as buf:\n",
|
||||
" pickle.dump({\n",
|
||||
" 'num_of_tokens': num_of_tokens,\n",
|
||||
" 'num_of_lines': num_of_lines,\n",
|
||||
" 'file_names': file_names,\n",
|
||||
" }, buf)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = train_dataset.map(_get_stats, batched=True, batch_size=50_000)\n",
|
||||
"eval_dataset = eval_dataset.map(_get_stats, batched=True, batch_size=50_000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"0it [00:00, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2708 > 2048). Running this sequence through the model will result in indexing errors\n",
|
||||
"61it [00:55, 1.36it/s]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"collapse_metrics(train_dataset, 'train_metrics.pkl')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"collapse_metrics(eval_dataset, 'eval_metrics.pkl')"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
2
LICENSE
2
LICENSE
@ -186,7 +186,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2021 Nathan Cooper and the amazing contributors of this project
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
50
README.md
50
README.md
@ -1,20 +1,20 @@
|
||||
# GPT-Code-Clippy (GPT-CC)
|
||||
|
||||
# Open Source GitHub Copilot for auto generating code
|
||||
## Open Source GitHub Copilot for auto generating code
|
||||
|
||||
I would like to train an open source version of the new awesome GitHub Copilot AI tool, which is based on GPT3. Similar to the awesome people behind GPT-Neo, having such an open source model would greatly help researchers understand what this type of biases and limitations this kind of code autocompletion model might have such as generating insecure code (i do research in this area and i know my team would love an open sourced version to run experiments on, i.e. try and break it 🤓)
|
||||
|
||||
## 2. Language
|
||||
## Getting the data
|
||||
|
||||
The model will be trained on different programming languages such as C, C++, java, python, etc.
|
||||
### Downloading the data
|
||||
|
||||
## 3. Model
|
||||
### Further processing the data
|
||||
|
||||
GPT-Neo
|
||||
## Finetuning the model
|
||||
|
||||
## 4. Datasets
|
||||
## Evaluating the model
|
||||
|
||||
Datasets that contain hopefully high quality source code
|
||||
## Using the model
|
||||
|
||||
Possible links to publicly available datasets include:
|
||||
- https://huggingface.co/datasets/code_search_net
|
||||
@ -28,17 +28,47 @@ I believe the standard CLM language model script would do for this.
|
||||
|
||||
We can make use of https://www.github.com/huggingface/transformers/tree/master/examples%2Fflax%2Flanguage-modeling%2Frun_clm_flax.py
|
||||
|
||||
## 6. (Optional) Challenges
|
||||
for training the scripts you can run:
|
||||
`python run_clm_streaming_flax_v2.py `
|
||||
|
||||
|
||||
## 6. Usage
|
||||
|
||||
|
||||
code for running the code generation is done by using:
|
||||
`bash run_cln_straming.sh`
|
||||
|
||||
run_cln_straming.sh contains all the hyperparameters and will be used to generate code.
|
||||
|
||||
we have also generated the code for the following languages:
|
||||
python
|
||||
javascript
|
||||
c++
|
||||
c
|
||||
java
|
||||
|
||||
We have used GPT-Neo using 13B and 27B parameter settings.
|
||||
|
||||
to run for following files:
|
||||
|
||||
13B:
|
||||
`bash run_cln_gpt_neo_13b.sh`
|
||||
|
||||
27B:
|
||||
`bash run_cln_gpt_neo_27b.sh`
|
||||
|
||||
|
||||
## 7. (Optional) Challenges
|
||||
|
||||
The data additional data may be a challenge. From what I can see in copilot, it looks to be training on entire files, not code snippets. There are file level datasets that exist but they are a few years old and i don't think they cover many programming languages. The ones I listed above have multiple languages but are only methods.
|
||||
|
||||
However, githubs API is pretty easy to use and so it would be pretty easy to create one from scratch, especially if we get some insights into how the copilot dataset was generated 🤓
|
||||
|
||||
## 7. (Optional) Desired project outcome
|
||||
## 8. (Optional) Desired project outcome
|
||||
|
||||
I'd love to have this open source model setup in a similar Visual Studio Code extension to the GitHub Copilot one. I've actually made a tutorial on doing this using the GPT-Neo model, so we could easily clean it up and release it free of charge forever because from what I've seen on Twitter the GitHub Copilot might eventually be put behind a paywall 😢.
|
||||
|
||||
## 8. (Optional) Reads
|
||||
## 9. (Optional) Reads
|
||||
|
||||
The following links can be useful to better understand the project and
|
||||
what has previously been done.
|
||||
|
31
data_processing/get_license_info.py
Normal file
31
data_processing/get_license_info.py
Normal file
@ -0,0 +1,31 @@
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from fastcore.script import *
|
||||
from ghapi.all import GhApi
|
||||
|
||||
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
|
||||
|
||||
|
||||
# Open issue on repo using custom title and body
|
||||
def get_license_info(owner, repo):
|
||||
api = GhApi(owner=owner, repo=repo, token=GITHUB_TOKEN)
|
||||
license = api.licenses.get_for_repo(owner=owner, repo=repo)
|
||||
return license.license.name
|
||||
|
||||
@call_parse
|
||||
def main(repos_path: Param("Path to the csv containing all of the repos", str)):
|
||||
"""
|
||||
Use pandas dataframe from the repos path to open issues in each of them.
|
||||
"""
|
||||
repos_path = Path(repos_path)
|
||||
df = pd.read_csv(repos_path)
|
||||
|
||||
# Loop through repos and get their license
|
||||
licenses = []
|
||||
for _, row in df.iterrows():
|
||||
owner, repo = row["name"].split("/")
|
||||
licenses.append(get_license_info(owner, repo))
|
||||
df["license"] = licenses
|
||||
df.to_csv(repos_path.parent/f"{repos_path.stem}_with_license.csv", index=False)
|
372
data_scripts/download_license_info.py
Normal file
372
data_scripts/download_license_info.py
Normal file
@ -0,0 +1,372 @@
|
||||
# Copyright maintained by EleutherAI. Originally from https://github.com/EleutherAI/github-downloader
|
||||
|
||||
# Modified from: https://github.com/ncoop57/gpt-code-clippy/blob/main/data_scripts/download_repo_text.py
|
||||
|
||||
import chardet
|
||||
import magic
|
||||
#import lm_dataformat as lmd
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
import shutil
|
||||
import csv
|
||||
import json
|
||||
from multiprocessing import cpu_count, Pool
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
import subprocess
|
||||
from itertools import repeat
|
||||
import pandas as pd
|
||||
|
||||
bad_extensions = [
|
||||
'app',
|
||||
'bin',
|
||||
'bmp',
|
||||
'bz2',
|
||||
'class',
|
||||
'csv',
|
||||
'dat',
|
||||
'db',
|
||||
'dll',
|
||||
'dylib',
|
||||
'egg',
|
||||
'eot',
|
||||
'exe',
|
||||
'gif',
|
||||
'gitignore',
|
||||
'glif',
|
||||
'gradle',
|
||||
'gz',
|
||||
'ico',
|
||||
'jar',
|
||||
'jpeg',
|
||||
'jpg',
|
||||
'lo',
|
||||
'lock',
|
||||
'log',
|
||||
'mp3',
|
||||
'mp4',
|
||||
'nar',
|
||||
'o',
|
||||
'ogg',
|
||||
'otf',
|
||||
'p',
|
||||
'pdf',
|
||||
'png',
|
||||
'pickle',
|
||||
'pkl',
|
||||
'pyc',
|
||||
'pyd',
|
||||
'pyo',
|
||||
'rkt',
|
||||
'so',
|
||||
'ss',
|
||||
'svg',
|
||||
'tar',
|
||||
'tsv',
|
||||
'ttf',
|
||||
'war',
|
||||
'webm',
|
||||
'woff',
|
||||
'woff2',
|
||||
'xz',
|
||||
'zip',
|
||||
'zst'
|
||||
]
|
||||
# load programming language extensions from json file
|
||||
with open("./Programming_Languages_Extensions.json", "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
lang_exts = []
|
||||
for i in data:
|
||||
if "extensions" not in i:
|
||||
continue
|
||||
lang_exts.extend(i["extensions"])
|
||||
|
||||
mime = magic.Magic(mime=True)
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def timeout(func, args=(), kwargs={}, timeout_duration=150, default=None):
|
||||
# wrap any function in this wrapper to raise a TimeoutError after timeout_duration secs
|
||||
import signal
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
|
||||
# set the timeout handler
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout_duration)
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except TimeoutError:
|
||||
result = default
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def split_into_chunks(l, n):
|
||||
n = max(1, n)
|
||||
return [l[i:i + n] for i in range(0, len(l), n)]
|
||||
|
||||
|
||||
def is_digit(x):
|
||||
return x in "1234567890"
|
||||
|
||||
|
||||
def keep(x):
|
||||
# simple filters to decide whether a file is worth keeping
|
||||
num_digits = len(list(filter(is_digit, x)))
|
||||
num_newlines = len(list(filter(lambda x: x == '\n', x)))
|
||||
if num_digits / len(x) > 0.8:
|
||||
return False
|
||||
|
||||
# avg line length
|
||||
if len(x) / (num_newlines + .001) > 200:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def filter_by_stars(repo_data, n_stars):
|
||||
return [item for item in repo_data if int(item[1]) >= n_stars]
|
||||
|
||||
|
||||
def get_content(f):
|
||||
# discerns filetype with mime and reads text from file if possible
|
||||
|
||||
type = None
|
||||
try:
|
||||
enc = 'utf-8'
|
||||
type = mime.from_file(f)
|
||||
if not type.startswith('text'):
|
||||
return
|
||||
with open(f, 'rb') as fromfh:
|
||||
buf = fromfh.read()
|
||||
|
||||
buf = buf.decode('UTF-8')
|
||||
if not keep(buf):
|
||||
return
|
||||
|
||||
return buf
|
||||
except UnicodeDecodeError:
|
||||
# bad encoding, try different encoding
|
||||
try:
|
||||
enc = None
|
||||
enc = chardet.detect(buf)
|
||||
if enc['encoding'] is None:
|
||||
return
|
||||
buf = buf.decode(enc['encoding'])
|
||||
if not keep(buf):
|
||||
return
|
||||
return buf
|
||||
except UnicodeDecodeError:
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
sys.exit()
|
||||
except FileNotFoundError:
|
||||
# bad symlink
|
||||
import os.path
|
||||
if not os.path.islink(f):
|
||||
# something went horribly wrong!
|
||||
...
|
||||
|
||||
# def filter_criteria(files):
|
||||
# filtered_files = []
|
||||
# for f in files:
|
||||
# size = os.path.getsize(f)
|
||||
# if '.git' not in f and f[0] is not '.' and \
|
||||
# 'LICENSE' not in f and 'node_modules' not in f and \
|
||||
# '.min.' not in f and f.split('.')[-1] not in bad_extensions and \
|
||||
# f.split('.')[-1] in lang_exts and size
|
||||
|
||||
|
||||
# This gets the license info from each repo
|
||||
def _process_repo(repo_data, repodir):
|
||||
out = None
|
||||
# get metadata
|
||||
name, stars, lang = repo_data
|
||||
meta = {'repo_name': str(name), 'stars': stars, 'repo_language': str(lang)}
|
||||
|
||||
|
||||
try:
|
||||
for curdir, dirs, files in os.walk(repodir):
|
||||
|
||||
#size = os.path.getsize('/usr/lib/python3.8/genericpath.py')
|
||||
|
||||
# files = [curdir + '/' + f for f in files if '.git' not in f and f[
|
||||
# 0] != '.' and 'LICENSE' == f and 'node_modules' not in f and '.min.' not in f and f.split('.')[
|
||||
# -1] not in bad_extensions and os.path.getsize('C:\\Python27\\Lib\\genericpath.py')]
|
||||
|
||||
files = [curdir + '/' + f for f in files if f == 'LICENSE']
|
||||
|
||||
filenames = [f.split("/")[-1] for f in files]
|
||||
|
||||
#extensions = []
|
||||
# for f in files:
|
||||
# try:
|
||||
# extensions.append(mime.from_file(f))
|
||||
# except FileNotFoundError:
|
||||
# extensions.append("n/a")
|
||||
text_outputs = []
|
||||
for f in files:
|
||||
try:
|
||||
text_outputs.append(get_content(f))
|
||||
except TimeoutError:
|
||||
raise TimeoutError
|
||||
except:
|
||||
err = traceback.format_exc()
|
||||
print(err)
|
||||
text_outputs.append(None)
|
||||
# for each license file
|
||||
for i in range(len(files)):
|
||||
text = text_outputs[i]
|
||||
|
||||
if text is not None:
|
||||
text_lines = text.splitlines()
|
||||
license_title = text_lines[0]
|
||||
# Maybe add a standardizer of licenses here
|
||||
meta['license'] = license_title
|
||||
#meta['file_name'] = filenames[i]
|
||||
# meta['mime_type'] = extensions[i]
|
||||
if out is None:
|
||||
out = [[text, meta]]
|
||||
|
||||
else:
|
||||
out.append([text, meta])
|
||||
|
||||
shutil.rmtree(repodir, ignore_errors=True)
|
||||
except TimeoutError:
|
||||
print(f"Processing for {name} timed out")
|
||||
return out
|
||||
|
||||
|
||||
def process_repo(repo_data, repodir, processing_timeout):
|
||||
return timeout(_process_repo, args=(repo_data, repodir), timeout_duration=processing_timeout)
|
||||
|
||||
|
||||
def process_repo_list(repo_data, clone_timeout, processing_timeout):
|
||||
out = None
|
||||
try:
|
||||
name, stars, lang = repo_data
|
||||
repodir = f'./.tmp/{name.split("/")[-1]}'
|
||||
# clones master branch of repos with depth 1 (most recent commit only), ignoring any terminal prompts
|
||||
p = subprocess.Popen(
|
||||
f'GIT_TERMINAL_PROMPT=0 git clone --depth 1 --single-branch https://github.com/{name} {repodir}',
|
||||
shell=True,
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
|
||||
try:
|
||||
p.wait(clone_timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f'Git clone for {name} timed out ')
|
||||
p.kill()
|
||||
shutil.rmtree(f'{repodir}/.git', ignore_errors=True)
|
||||
# extracts text files from repo and returns them as list : [[text, metadata], ... ]
|
||||
out = process_repo(repo_data, repodir, processing_timeout=processing_timeout)
|
||||
except Exception:
|
||||
err = traceback.format_exc()
|
||||
if verbose:
|
||||
print(err)
|
||||
return out
|
||||
|
||||
|
||||
def process_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='CLI for github downloader - A tool for scraping repos as text from github')
|
||||
parser.add_argument('--n_threads', help='number of threads for parallel processing, defaults to cpu_count',
|
||||
default=-1,
|
||||
type=int)
|
||||
parser.add_argument('--n_stars', help='filter repos with less than n_stars stars',
|
||||
default=-1,
|
||||
type=int)
|
||||
parser.add_argument('--chunk_size', help='size of chunks to feed into each thread',
|
||||
default=-1,
|
||||
type=int)
|
||||
parser.add_argument('--clone_timeout', help='timeout for git clone command in seconds',
|
||||
default=150,
|
||||
type=int)
|
||||
parser.add_argument('--processing_timeout', help='timeout for processing repo to text files in seconds',
|
||||
default=150,
|
||||
type=int)
|
||||
parser.add_argument('--commit_freq', help='how often (in number of chunks) to commit the archive file',
|
||||
default=10,
|
||||
type=int)
|
||||
parser.add_argument('-v', '--verbose', help='if flag is present, print errors', action='store_true')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
args = process_args() # parse args
|
||||
verbose = args.verbose
|
||||
|
||||
# make output dirs
|
||||
if '.tmp' not in os.listdir():
|
||||
os.makedirs('.tmp')
|
||||
if 'github_data' not in os.listdir():
|
||||
os.makedirs('github_data')
|
||||
|
||||
# read repo data to a tuple (reponame, n_stars, language)
|
||||
with open('github_repositories.csv', 'r') as f:
|
||||
csv_reader = csv.reader(f)
|
||||
repo_data = list(map(tuple, csv_reader))
|
||||
|
||||
# filter by number of stars
|
||||
if args.n_stars != -1:
|
||||
repo_data = filter_by_stars(repo_data, args.n_stars)
|
||||
repo_data.sort()
|
||||
|
||||
random.seed(420)
|
||||
random.shuffle(repo_data)
|
||||
|
||||
n_threads = cpu_count() * 3 if args.n_threads == -1 else args.n_threads
|
||||
chunk_size = n_threads * 3 if args.chunk_size == -1 else args.chunk_size
|
||||
|
||||
assert n_threads != 0
|
||||
|
||||
# do work
|
||||
repo_chunks = split_into_chunks(repo_data, chunk_size)
|
||||
archive_name = 'github_data'
|
||||
# ar = lmd.Archive(archive_name)
|
||||
pool = Pool(n_threads)
|
||||
new_repo_df = {'repo_name': [], 'stars': [], 'repo_language': [], 'license': [], 'full_license_text':[]}
|
||||
pbar = tqdm(repo_chunks, total=len(repo_chunks))
|
||||
success_hist = []
|
||||
for count, chunk in enumerate(pbar):
|
||||
repos_out = pool.starmap(process_repo_list,
|
||||
zip(chunk, repeat(args.clone_timeout), repeat(args.processing_timeout)))
|
||||
not_none = 0
|
||||
none = 0
|
||||
for repo in repos_out:
|
||||
if repo is not None:
|
||||
not_none += 1
|
||||
for f in repo:
|
||||
new_repo_df['repo_name'].append(str(f[1]['repo_name']))
|
||||
new_repo_df['stars'].append(f[1]['stars'])
|
||||
new_repo_df['repo_language'].append(str(f[1]['repo_language']))
|
||||
new_repo_df['license'].append(f[1]['license'])
|
||||
new_repo_df['full_license_text'].append(str(f[0]))
|
||||
|
||||
#ar.add_data(f[0], meta=f[1])
|
||||
else:
|
||||
none += 1
|
||||
|
||||
# remove any leftover files
|
||||
subprocess.Popen("rm -rfv .tmp && mkdir .tmp", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
|
||||
# if count % args.commit_freq == 0:
|
||||
# ar.commit()
|
||||
success_hist.append((not_none / len(repos_out)) * 100)
|
||||
success_rate = sum(success_hist) / len(success_hist)
|
||||
pbar.set_postfix({"Success Rate": success_rate, 'Data Size': len(new_repo_df)})
|
||||
#pbar.set_postfix({"Success Rate": success_rate})
|
||||
new_repo_df = pd.DataFrame.from_dict(new_repo_df)
|
||||
new_repo_df.to_csv('github_repositories_with_licenses.csv')
|
||||
new_repo_df.to_feather('github_repositories_with_licenses.feather')
|
||||
|
||||
# ar.commit() # final commit
|
1
dependency_repos/CodeXGLUE
Submodule
1
dependency_repos/CodeXGLUE
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 3e7bfe6dc4a88534c7803ce1bd8d1733c1d16888
|
1
dependency_repos/apps
Submodule
1
dependency_repos/apps
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit f834ca7d7405935376aabb5830edd0c42635824e
|
1
dependency_repos/human-eval
Submodule
1
dependency_repos/human-eval
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 463c980b59e818ace59f6f9803cd92c749ceae61
|
200
docs/DATASHEET.md
Normal file
200
docs/DATASHEET.md
Normal file
@ -0,0 +1,200 @@
|
||||
---
|
||||
YAML tags:
|
||||
- copy-paste the tags obtained with the tagging app: https://github.com/huggingface/datasets-tagging
|
||||
---
|
||||
|
||||
# Dataset Card Creation Guide
|
||||
|
||||
## Table of Contents
|
||||
- [Dataset Card Creation Guide](#dataset-card-creation-guide)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [Dataset Description](#dataset-description)
|
||||
- [Dataset Summary](#dataset-summary)
|
||||
- [Supported Tasks and Leaderboards](#supported-tasks-and-leaderboards)
|
||||
- [Languages](#languages)
|
||||
- [Dataset Structure](#dataset-structure)
|
||||
- [Data Instances](#data-instances)
|
||||
- [Data Fields](#data-fields)
|
||||
- [Data Splits](#data-splits)
|
||||
- [Dataset Creation](#dataset-creation)
|
||||
- [Curation Rationale](#curation-rationale)
|
||||
- [Source Data](#source-data)
|
||||
- [Initial Data Collection and Normalization](#initial-data-collection-and-normalization)
|
||||
- [Who are the source language producers?](#who-are-the-source-language-producers)
|
||||
- [Annotations](#annotations)
|
||||
- [Annotation process](#annotation-process)
|
||||
- [Who are the annotators?](#who-are-the-annotators)
|
||||
- [Personal and Sensitive Information](#personal-and-sensitive-information)
|
||||
- [Considerations for Using the Data](#considerations-for-using-the-data)
|
||||
- [Social Impact of Dataset](#social-impact-of-dataset)
|
||||
- [Discussion of Biases](#discussion-of-biases)
|
||||
- [Other Known Limitations](#other-known-limitations)
|
||||
- [Additional Information](#additional-information)
|
||||
- [Dataset Curators](#dataset-curators)
|
||||
- [Licensing Information](#licensing-information)
|
||||
- [Citation Information](#citation-information)
|
||||
- [Contributions](#contributions)
|
||||
|
||||
## Dataset Description
|
||||
|
||||
- **Homepage:** [Add homepage URL here if available (unless it's a GitHub repository)]()
|
||||
- **Repository:** [If the dataset is hosted on github or has a github homepage, add URL here]()
|
||||
- **Paper:** [If the dataset was introduced by a paper or there was a paper written describing the dataset, add URL here (landing page for Arxiv paper preferred)]()
|
||||
- **Leaderboard:** [If the dataset supports an active leaderboard, add link here]()
|
||||
- **Point of Contact:** [If known, name and email of at least one person the reader can contact for questions about the dataset.]()
|
||||
|
||||
### Dataset Summary
|
||||
|
||||
Briefly summarize the dataset, its intended use and the supported tasks. Give an overview of how and why the dataset was created. The summary should explicitly mention the languages present in the dataset (possibly in broad terms, e.g. *translations between several pairs of European languages*), and describe the domain, topic, or genre covered.
|
||||
|
||||
### Supported Tasks and Leaderboards
|
||||
|
||||
For each of the tasks tagged for this dataset, give a brief description of the tag, metrics, and suggested models (with a link to their HuggingFace implementation if available). Give a similar description of tasks that were not covered by the structured tag set (repace the `task-category-tag` with an appropriate `other:other-task-name`).
|
||||
|
||||
- `task-category-tag`: The dataset can be used to train a model for [TASK NAME], which consists in [TASK DESCRIPTION]. Success on this task is typically measured by achieving a *high/low* [metric name](https://huggingface.co/metrics/metric_name). The ([model name](https://huggingface.co/model_name) or [model class](https://huggingface.co/transformers/model_doc/model_class.html)) model currently achieves the following score. *[IF A LEADERBOARD IS AVAILABLE]:* This task has an active leaderboard which can be found at [leaderboard url]() and ranks models based on [metric name](https://huggingface.co/metrics/metric_name) while also reporting [other metric name](https://huggingface.co/metrics/other_metric_name).
|
||||
|
||||
### Languages
|
||||
|
||||
Provide a brief overview of the languages represented in the dataset. Describe relevant details about specifics of the language such as whether it is social media text, African American English,...
|
||||
|
||||
When relevant, please provide [BCP-47 codes](https://tools.ietf.org/html/bcp47), which consist of a [primary language subtag](https://tools.ietf.org/html/bcp47#section-2.2.1), with a [script subtag](https://tools.ietf.org/html/bcp47#section-2.2.3) and/or [region subtag](https://tools.ietf.org/html/bcp47#section-2.2.4) if available.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
### Data Instances
|
||||
|
||||
Provide an JSON-formatted example and brief description of a typical instance in the dataset. If available, provide a link to further examples.
|
||||
|
||||
```
|
||||
{
|
||||
'example_field': ...,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Provide any additional information that is not covered in the other sections about the data here. In particular describe any relationships between data points and if these relationships are made explicit.
|
||||
|
||||
### Data Fields
|
||||
|
||||
List and describe the fields present in the dataset. Mention their data type, and whether they are used as input or output in any of the tasks the dataset currently supports. If the data has span indices, describe their attributes, such as whether they are at the character level or word level, whether they are contiguous or not, etc. If the datasets contains example IDs, state whether they have an inherent meaning, such as a mapping to other datasets or pointing to relationships between data points.
|
||||
|
||||
- `example_field`: description of `example_field`
|
||||
|
||||
Note that the descriptions can be initialized with the **Show Markdown Data Fields** output of the [tagging app](https://github.com/huggingface/datasets-tagging), you will then only need to refine the generated descriptions.
|
||||
|
||||
### Data Splits
|
||||
|
||||
Describe and name the splits in the dataset if there are more than one.
|
||||
|
||||
Describe any criteria for splitting the data, if used. If their are differences between the splits (e.g. if the training annotations are machine-generated and the dev and test ones are created by humans, or if different numbers of annotators contributed to each example), describe them here.
|
||||
|
||||
Provide the sizes of each split. As appropriate, provide any descriptive statistics for the features, such as average length. For example:
|
||||
|
||||
| | Tain | Valid | Test |
|
||||
| ----- | ------ | ----- | ---- |
|
||||
| Input Sentences | | | |
|
||||
| Average Sentence Length | | | |
|
||||
|
||||
## Dataset Creation
|
||||
|
||||
### Curation Rationale
|
||||
|
||||
What need motivated the creation of this dataset? What are some of the reasons underlying the major choices involved in putting it together?
|
||||
|
||||
### Source Data
|
||||
|
||||
This section describes the source data (e.g. news text and headlines, social media posts, translated sentences,...)
|
||||
|
||||
#### Initial Data Collection and Normalization
|
||||
|
||||
Describe the data collection process. Describe any criteria for data selection or filtering. List any key words or search terms used. If possible, include runtime information for the collection process.
|
||||
|
||||
If data was collected from other pre-existing datasets, link to source here and to their [Hugging Face version](https://huggingface.co/datasets/dataset_name).
|
||||
|
||||
If the data was modified or normalized after being collected (e.g. if the data is word-tokenized), describe the process and the tools used.
|
||||
|
||||
#### Who are the source language producers?
|
||||
|
||||
State whether the data was produced by humans or machine generated. Describe the people or systems who originally created the data.
|
||||
|
||||
If available, include self-reported demographic or identity information for the source data creators, but avoid inferring this information. Instead state that this information is unknown. See [Larson 2017](https://www.aclweb.org/anthology/W17-1601.pdf) for using identity categories as a variables, particularly gender.
|
||||
|
||||
Describe the conditions under which the data was created (for example, if the producers were crowdworkers, state what platform was used, or if the data was found, what website the data was found on). If compensation was provided, include that information here.
|
||||
|
||||
Describe other people represented or mentioned in the data. Where possible, link to references for the information.
|
||||
|
||||
### Annotations
|
||||
|
||||
If the dataset contains annotations which are not part of the initial data collection, describe them in the following paragraphs.
|
||||
|
||||
#### Annotation process
|
||||
|
||||
If applicable, describe the annotation process and any tools used, or state otherwise. Describe the amount of data annotated, if not all. Describe or reference annotation guidelines provided to the annotators. If available, provide interannotator statistics. Describe any annotation validation processes.
|
||||
|
||||
#### Who are the annotators?
|
||||
|
||||
If annotations were collected for the source data (such as class labels or syntactic parses), state whether the annotations were produced by humans or machine generated.
|
||||
|
||||
Describe the people or systems who originally created the annotations and their selection criteria if applicable.
|
||||
|
||||
If available, include self-reported demographic or identity information for the annotators, but avoid inferring this information. Instead state that this information is unknown. See [Larson 2017](https://www.aclweb.org/anthology/W17-1601.pdf) for using identity categories as a variables, particularly gender.
|
||||
|
||||
Describe the conditions under which the data was annotated (for example, if the annotators were crowdworkers, state what platform was used, or if the data was found, what website the data was found on). If compensation was provided, include that information here.
|
||||
|
||||
### Personal and Sensitive Information
|
||||
|
||||
State whether the dataset uses identity categories and, if so, how the information is used. Describe where this information comes from (i.e. self-reporting, collecting from profiles, inferring, etc.). See [Larson 2017](https://www.aclweb.org/anthology/W17-1601.pdf) for using identity categories as a variables, particularly gender. State whether the data is linked to individuals and whether those individuals can be identified in the dataset, either directly or indirectly (i.e., in combination with other data).
|
||||
|
||||
State whether the dataset contains other data that might be considered sensitive (e.g., data that reveals racial or ethnic origins, sexual orientations, religious beliefs, political opinions or union memberships, or locations; financial or health data; biometric or genetic data; forms of government identification, such as social security numbers; criminal history).
|
||||
|
||||
If efforts were made to anonymize the data, describe the anonymization process.
|
||||
|
||||
## Considerations for Using the Data
|
||||
|
||||
### Social Impact of Dataset
|
||||
|
||||
Please discuss some of the ways you believe the use of this dataset will impact society.
|
||||
|
||||
The statement should include both positive outlooks, such as outlining how technologies developed through its use may improve people's lives, and discuss the accompanying risks. These risks may range from making important decisions more opaque to people who are affected by the technology, to reinforcing existing harmful biases (whose specifics should be discussed in the next section), among other considerations.
|
||||
|
||||
Also describe in this section if the proposed dataset contains a low-resource or under-represented language. If this is the case or if this task has any impact on underserved communities, please elaborate here.
|
||||
|
||||
### Discussion of Biases
|
||||
|
||||
Provide descriptions of specific biases that are likely to be reflected in the data, and state whether any steps were taken to reduce their impact.
|
||||
|
||||
For Wikipedia text, see for example [Dinan et al 2020 on biases in Wikipedia (esp. Table 1)](https://arxiv.org/abs/2005.00614), or [Blodgett et al 2020](https://www.aclweb.org/anthology/2020.acl-main.485/) for a more general discussion of the topic.
|
||||
|
||||
If analyses have been run quantifying these biases, please add brief summaries and links to the studies here.
|
||||
|
||||
### Other Known Limitations
|
||||
|
||||
If studies of the datasets have outlined other limitations of the dataset, such as annotation artifacts, please outline and cite them here.
|
||||
|
||||
## Additional Information
|
||||
|
||||
### Dataset Curators
|
||||
|
||||
List the people involved in collecting the dataset and their affiliation(s). If funding information is known, include it here.
|
||||
|
||||
### Licensing Information
|
||||
|
||||
Provide the license and link to the license webpage if available.
|
||||
|
||||
### Citation Information
|
||||
|
||||
Provide the [BibTex](http://www.bibtex.org/)-formatted reference for the dataset. For example:
|
||||
```
|
||||
@article{article_id,
|
||||
author = {Author List},
|
||||
title = {Dataset Paper Title},
|
||||
journal = {Publication Venue},
|
||||
year = {2525}
|
||||
}
|
||||
```
|
||||
|
||||
If the dataset has a [DOI](https://www.doi.org/), please provide it here.
|
||||
|
||||
### Contributions
|
||||
|
||||
Thanks to [@github-username](https://github.com/<github-username>) for adding this dataset.
|
98
docs/MODELCARD.md
Normal file
98
docs/MODELCARD.md
Normal file
@ -0,0 +1,98 @@
|
||||
---
|
||||
language:
|
||||
- ru
|
||||
- en
|
||||
thumbnail: https://raw.githubusercontent.com/JetRunner/BERT-of-Theseus/master/bert-of-theseus.png
|
||||
tags:
|
||||
- translation
|
||||
- fsmt
|
||||
license: Apache 2.0
|
||||
datasets:
|
||||
- wmt19
|
||||
metrics:
|
||||
- bleu
|
||||
- sacrebleu
|
||||
---
|
||||
|
||||
# MyModel
|
||||
|
||||
## Model description
|
||||
|
||||
This is a ported version of [fairseq wmt19 transformer](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md) for {src_lang}-{tgt_lang}.
|
||||
|
||||
For more details, please see, [Facebook FAIR's WMT19 News Translation Task Submission](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
The abbreviation FSMT stands for FairSeqMachineTranslation
|
||||
|
||||
All four models are available:
|
||||
|
||||
* [wmt19-en-ru](https://huggingface.co/facebook/wmt19-en-ru)
|
||||
* [wmt19-ru-en](https://huggingface.co/facebook/wmt19-ru-en)
|
||||
* [wmt19-en-de](https://huggingface.co/facebook/wmt19-en-de)
|
||||
* [wmt19-de-en](https://huggingface.co/facebook/wmt19-de-en)
|
||||
|
||||
## Intended uses & limitations
|
||||
|
||||
#### How to use
|
||||
|
||||
```python
|
||||
from transformers.tokenization_fsmt import FSMTTokenizer
|
||||
from transformers.modeling_fsmt import FSMTForConditionalGeneration
|
||||
mname = "facebook/wmt19-ru-en"
|
||||
tokenizer = FSMTTokenizer.from_pretrained(mname)
|
||||
model = FSMTForConditionalGeneration.from_pretrained(mname)
|
||||
|
||||
input = "Машинное обучение - это здорово, не так ли?"
|
||||
input_ids = tokenizer.encode(input, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(decoded) # Machine learning is great, isn't it?
|
||||
```
|
||||
|
||||
#### Limitations and bias
|
||||
|
||||
- The original (and this ported model) doesn't seem to handle well inputs with repeated sub-phrases, [content gets truncated](https://discuss.huggingface.co/t/issues-with-translating-inputs-containing-repeated-phrases/981)
|
||||
|
||||
|
||||
## Training data
|
||||
|
||||
Pretrained weights were left identical to the original model released by fairseq. For more details, please, see the [paper](https://arxiv.org/abs/1907.06616).
|
||||
|
||||
|
||||
## Training procedure
|
||||
|
||||
|
||||
## Eval results
|
||||
|
||||
pair | fairseq | transformers
|
||||
-------|---------|----------
|
||||
ru-en | [41.3](http://matrix.statmt.org/matrix/output/1907?run_id=6937) | 39.20
|
||||
|
||||
|
||||
The score was calculated using this code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
export PAIR=ru-en
|
||||
export DATA_DIR=data/$PAIR
|
||||
export SAVE_DIR=data/$PAIR
|
||||
export BS=8
|
||||
export NUM_BEAMS=15
|
||||
mkdir -p $DATA_DIR
|
||||
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||||
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||||
echo $PAIR
|
||||
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||||
```
|
||||
|
||||
### BibTeX entry and citation info
|
||||
|
||||
```bibtex
|
||||
@inproceedings{...,
|
||||
year={2020},
|
||||
title={Facebook FAIR's WMT19 News Translation Task Submission},
|
||||
author={Ng, Nathan and Yee, Kyra and Baevski, Alexei and Ott, Myle and Auli, Michael and Edunov, Sergey},
|
||||
booktitle={Proc. of WMT},
|
||||
}
|
||||
```
|
814
evaluation/data_processing.ipynb
Normal file
814
evaluation/data_processing.ipynb
Normal file
File diff suppressed because one or more lines are too long
172
evaluation/evaluate.py
Normal file
172
evaluation/evaluate.py
Normal file
@ -0,0 +1,172 @@
|
||||
import json
|
||||
import torch
|
||||
import pandas as pd
|
||||
|
||||
# import apps.eval.reident
|
||||
|
||||
# from apps_utils.generate_gpt_codes import generate_prompt
|
||||
# from apps_utils.test_one_solution import eval_and_save_problems
|
||||
from datasets import load_dataset, load_metric
|
||||
from fastcore.script import *
|
||||
from human_eval.data import HUMAN_EVAL, write_jsonl, read_problems
|
||||
from human_eval.evaluation import evaluate_functional_correctness
|
||||
from pathlib import Path
|
||||
# from metrics.extrinsic_eval import compute_metrics
|
||||
from subprocess import check_output
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
FlaxGPTNeoForCausalLM,
|
||||
)
|
||||
|
||||
bleu = load_metric("sacrebleu")
|
||||
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
|
||||
# model = AutoModelWithLMHead.from_pretrained(
|
||||
# "/home/nathan/gpt-code-clippy/data/APPS/models/1.5B"
|
||||
# )
|
||||
|
||||
MAX_TOKENs = 1_024
|
||||
model_name_or_path = "EleutherAI/gpt-neo-125M"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path, padding_side="left", pad_token="<|endoftext|>"
|
||||
)
|
||||
model = FlaxGPTNeoForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
pad_token_id=50256,
|
||||
)
|
||||
|
||||
|
||||
def generate_text(prompt):
|
||||
inputs = tokenizer(prompt, return_tensors="jax")#.to("cuda")
|
||||
output_seq = model.generate(input_ids=inputs.input_ids, max_length=1_024)
|
||||
|
||||
return tokenizer.decode(output_seq["sequences"][0])
|
||||
|
||||
|
||||
def _eval_concode(path):
|
||||
# TODO: format input to model same as App and OpenAI HumanEval datasets are formatted
|
||||
data = load_dataset("json", data_files=str(path / "test.json"))["train"]
|
||||
predictions = [[]]
|
||||
references = []
|
||||
for example in data:
|
||||
output = generate_text(example["nl"])
|
||||
predictions[0].append(output.split(" "))
|
||||
references.append(example["code"].split(" "))
|
||||
results = compute_metrics(predictions, references)
|
||||
print(f"Bleu score for Concode dataset: {results}")
|
||||
|
||||
|
||||
def _eval_apps(path):
|
||||
gpt_codes = {}
|
||||
prob_paths = sorted(path.glob("*/"))
|
||||
# map prob_paths to strings and save as a json file
|
||||
str_paths = [str(p) for p in prob_paths]
|
||||
with open(path / "test.json", "w") as f:
|
||||
json.dump(str_paths, f)
|
||||
for index, prob_path in enumerate(prob_paths[:2]):
|
||||
test_case_path = prob_path / "input_output.json"
|
||||
prompt_path = prob_path / "question.txt"
|
||||
starter_path = prob_path / "starter_code.py"
|
||||
solutions_path = prob_path / "solutions.json"
|
||||
if not starter_path.exists():
|
||||
starter_path = None
|
||||
if not test_case_path.exists() or not prompt_path.exists():
|
||||
continue
|
||||
prompt = generate_prompt(
|
||||
Args(),
|
||||
test_case_path,
|
||||
prompt_path,
|
||||
solutions_path,
|
||||
tokenizer,
|
||||
starter_path=starter_path,
|
||||
)
|
||||
output = generate_text(prompt)
|
||||
print(output)
|
||||
# print(output)
|
||||
gpt_codes[index] = output
|
||||
# print(output)
|
||||
|
||||
with open(path.parent / "all_codes.json", "w") as f:
|
||||
json.dump(gpt_codes, f)
|
||||
|
||||
eval_and_save_problems(path, path.parent)
|
||||
|
||||
# execute bash command to run eval script
|
||||
# results = check_output(
|
||||
# [
|
||||
# # python3 test_one_solution.py -t /path/to/apps/test --save /path/to/save_dir --print_results
|
||||
# "python",
|
||||
# "./apps_utils/test_one_solution.py",
|
||||
# "-t",
|
||||
# str(path),
|
||||
# "--save",
|
||||
# str(path.parent),
|
||||
# "--print_results",
|
||||
# ]
|
||||
# ).decode("utf-8")
|
||||
|
||||
|
||||
# test_case_path = os.path.join(prob_path, "input_output.json")
|
||||
# prompt_path = os.path.join(prob_path, "question.txt")
|
||||
# starter_path = os.path.join(prob_path, "starter_code.py")
|
||||
# solutions_path = os.path.join(prob_path, "solutions.json")
|
||||
# generate_prompt(args, test_case_path, prompt_path, solutions_path, tokenizer, starter_path=None)
|
||||
|
||||
|
||||
def _eval_human_eval(path):
|
||||
|
||||
problems = read_problems(str(path))
|
||||
num_samples_per_task = 1
|
||||
samples = [
|
||||
dict(
|
||||
task_id=task_id,
|
||||
completion=generate_text(problems[task_id]["prompt"]),
|
||||
)
|
||||
for task_id in problems
|
||||
for _ in range(num_samples_per_task)
|
||||
]
|
||||
write_jsonl("human_eval.jsonl", samples)
|
||||
# execute bash command to run eval script
|
||||
results = evaluate_functional_correctness("human_eval.jsonl", [1], 4, 3.0, str(path))
|
||||
# results = check_output(
|
||||
# [
|
||||
# "python",
|
||||
# path / "evaluate_functional_correctness.py",
|
||||
# "human_eval.jsonl",
|
||||
# ]
|
||||
# ).decode("utf-8")
|
||||
|
||||
print(results)
|
||||
|
||||
|
||||
@call_parse
|
||||
def main(
|
||||
concode_path: Param("Path to the concode data in CodeXGLUE", str),
|
||||
apps_path: Param("Path to the the App dataset", str),
|
||||
human_eval_path: Param("Path to the human eval dataset", str),
|
||||
):
|
||||
concode_path = Path(concode_path)
|
||||
apps_path = Path(apps_path)
|
||||
human_eval_path = Path(human_eval_path)
|
||||
# _eval_concode(concode_path)
|
||||
_eval_human_eval(human_eval_path)
|
||||
# _eval_apps(apps_path)
|
||||
# dataset = load_dataset("json", data_files=str(concode_path / "test.json"))
|
||||
# print(dataset)
|
||||
# results = bleu.compute(predictions=predictions, references=references)
|
||||
# print(list(results.keys()))
|
||||
# print(round(results["score"], 1))
|
||||
|
||||
|
||||
# problems = read_problems()
|
||||
# print(problems)
|
||||
# num_samples_per_task = 200
|
||||
# samples = [
|
||||
# dict(
|
||||
# task_id=task_id,
|
||||
# completion=generate_text(problems[task_id]["prompt"]),
|
||||
# )
|
||||
# for task_id in problems[:1]
|
||||
# for _ in range(num_samples_per_task)
|
||||
# ]
|
||||
# write_jsonl("human_eval.jsonl", samples)
|
102
evaluation/evaluation/apps_utils/generate_gpt_codes.py
Normal file
102
evaluation/evaluation/apps_utils/generate_gpt_codes.py
Normal file
@ -0,0 +1,102 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Dan Hendrycks and contributors.
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
|
||||
"""
|
||||
Run a tranined model to generate Python code.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import time
|
||||
import transformers
|
||||
import torch
|
||||
|
||||
from apps_utils.reindent import run as run_reindent
|
||||
|
||||
# for timing and debugging
|
||||
from datetime import datetime, date
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def reindent_code(codestr):
|
||||
"""
|
||||
Given code string, reindent it in the same way that the
|
||||
Github dataset was indented
|
||||
"""
|
||||
codestr = io.StringIO(codestr)
|
||||
ret = io.StringIO()
|
||||
|
||||
run_reindent(
|
||||
codestr,
|
||||
ret,
|
||||
config={
|
||||
"dry-run": False,
|
||||
"help": False,
|
||||
"to": 10,
|
||||
"from": -1,
|
||||
"tabs": True,
|
||||
"encoding": "utf-8",
|
||||
"is-tabs": False,
|
||||
"tabsize": 10,
|
||||
"all-tabs": False,
|
||||
},
|
||||
)
|
||||
|
||||
return ret.getvalue()
|
||||
|
||||
|
||||
def generate_prompt(
|
||||
test_case_path, prompt_path, solutions_path, tokenizer, starter_path=None
|
||||
):
|
||||
_input = "\nQUESTION:\n"
|
||||
with open(prompt_path, "r") as f:
|
||||
data = f.readlines()
|
||||
data = "".join(data)
|
||||
_input += data
|
||||
if starter_path != None:
|
||||
with open(starter_path, "r") as f:
|
||||
data = f.readlines()
|
||||
data = "".join(data)
|
||||
data = "\n" + data # + "\n"
|
||||
_input += data
|
||||
else:
|
||||
# _input += "\n\n"
|
||||
pass
|
||||
|
||||
with open(test_case_path, "r") as f:
|
||||
data = json.load(f)
|
||||
if not data.get("fn_name"):
|
||||
_input += "\nUse Standard Input format" # \n"
|
||||
else:
|
||||
_input += "\nUse Call-Based format" # \n"
|
||||
|
||||
_input += "\nANSWER:\n"
|
||||
|
||||
return _input
|
227
evaluation/evaluation/apps_utils/reindent.py
Normal file
227
evaluation/evaluation/apps_utils/reindent.py
Normal file
@ -0,0 +1,227 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Dan Hendrycks and contributors.
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
|
||||
"""
|
||||
Reindent files.
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
import sys
|
||||
import getopt
|
||||
import codecs
|
||||
import tempfile
|
||||
import shutil
|
||||
import os
|
||||
|
||||
|
||||
def _find_indentation(line, config):
|
||||
if len(line) and line[0] in (" ", "\t") and not line.isspace():
|
||||
if line[0] == "\t":
|
||||
config['is-tabs'] = True
|
||||
# Find indentation
|
||||
i = 0
|
||||
for char in list(line):
|
||||
if char not in (" ", "\t"):
|
||||
break
|
||||
i += 1
|
||||
config["from"] = i
|
||||
|
||||
|
||||
def find_indentation(line, config):
|
||||
# Find indentation level used in file
|
||||
if config['from'] < 0:
|
||||
_find_indentation(line, config)
|
||||
|
||||
if config['from'] >= 0:
|
||||
# Set old indent
|
||||
indent = " " if not config['is-tabs'] else "\t"
|
||||
indent = indent * config['from']
|
||||
|
||||
# Set new indent
|
||||
newindent = " " if not config['tabs'] else "\t"
|
||||
if not config['tabs']:
|
||||
newindent = newindent * config['to']
|
||||
|
||||
return indent, newindent
|
||||
|
||||
# Continue to the next line, indentation not found
|
||||
return False
|
||||
|
||||
|
||||
def replace_inline_tabs(content, config):
|
||||
newcontent = ""
|
||||
imagined_i = 0
|
||||
for i in range(0, len(content)):
|
||||
char = content[i]
|
||||
if char == '\t':
|
||||
spaces = config['tabsize']-(imagined_i % config['tabsize'])
|
||||
newcontent += " " * spaces
|
||||
imagined_i += spaces
|
||||
else:
|
||||
newcontent += char
|
||||
imagined_i += 1
|
||||
return newcontent
|
||||
|
||||
|
||||
def run(fd_in, fd_out, config):
|
||||
from reindent_4_spaces import Reindenter
|
||||
import io
|
||||
|
||||
inter = io.StringIO()
|
||||
ri = Reindenter(fd_in)
|
||||
ri.run()
|
||||
ri.write(inter)
|
||||
fd_in = inter
|
||||
fd_in.seek(0)
|
||||
|
||||
while True:
|
||||
line = fd_in.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.rstrip('\r\n')
|
||||
|
||||
# Find indentation style used in file if not set
|
||||
if config['from'] < 0:
|
||||
indent = find_indentation(line, config)
|
||||
if not indent:
|
||||
print(line, file=fd_out)
|
||||
continue
|
||||
indent, newindent = indent
|
||||
|
||||
# Find current indentation level
|
||||
level = 0
|
||||
while True:
|
||||
whitespace = line[:len(indent) * (level + 1)]
|
||||
if whitespace == indent * (level + 1):
|
||||
level += 1
|
||||
else:
|
||||
break
|
||||
|
||||
content = line[len(indent) * level:]
|
||||
if config['all-tabs']:
|
||||
content = replace_inline_tabs(content, config)
|
||||
|
||||
line = (newindent * level) + content
|
||||
print(line, file=fd_out)
|
||||
# print(config)
|
||||
|
||||
|
||||
def run_files(filenames, config):
|
||||
for filename in filenames:
|
||||
with codecs.open(filename, encoding=config['encoding']) as fd_in:
|
||||
if config['dry-run']:
|
||||
print("Filename: %s" % filename)
|
||||
fd_out = sys.stdout
|
||||
else:
|
||||
fd_out = tempfile.NamedTemporaryFile(mode='wb', delete=False)
|
||||
fd_out.close()
|
||||
fd_out = codecs.open(fd_out.name, "wb", encoding=config['encoding'])
|
||||
|
||||
run(fd_in, fd_out, config)
|
||||
|
||||
if not config["dry-run"]:
|
||||
fd_out.close()
|
||||
shutil.copy(fd_out.name, filename)
|
||||
os.remove(fd_out.name)
|
||||
|
||||
|
||||
def main(args):
|
||||
config = {
|
||||
"dry-run": False,
|
||||
"help": False,
|
||||
"to": 4,
|
||||
"from": -1,
|
||||
"tabs": False,
|
||||
"encoding": "utf-8",
|
||||
"is-tabs": False,
|
||||
"tabsize": 4,
|
||||
"all-tabs": False
|
||||
}
|
||||
possible_args = {
|
||||
"d": "dry-run",
|
||||
"h": "help",
|
||||
"t:": "to=",
|
||||
"f:": "from=",
|
||||
"n": "tabs",
|
||||
"e:": "encoding=",
|
||||
"s:": "tabsize=",
|
||||
"a": "all-tabs",
|
||||
}
|
||||
optlist, filenames = getopt.getopt(
|
||||
args[1:],
|
||||
"".join(possible_args.keys()),
|
||||
possible_args.values()
|
||||
)
|
||||
|
||||
shortargs, longargs = [], []
|
||||
for shortarg in possible_args:
|
||||
shortargs.append(shortarg.rstrip(":"))
|
||||
longargs.append(possible_args[shortarg].rstrip("="))
|
||||
|
||||
for opt, val in optlist:
|
||||
opt = opt.lstrip("-")
|
||||
if opt in shortargs:
|
||||
opt = longargs[shortargs.index(opt)]
|
||||
if isinstance(config[opt], bool):
|
||||
config[opt] = True
|
||||
elif isinstance(config[opt], int):
|
||||
config[opt] = int(val)
|
||||
else:
|
||||
config[opt] = val
|
||||
|
||||
if config['help']:
|
||||
help = """
|
||||
Usage: %s [options] filename(s)
|
||||
Options:
|
||||
-h, --help Show this message
|
||||
-d, --dry-run Don't save anything, just print
|
||||
the result
|
||||
-t <n>, --to <n> Convert to this number of spaces
|
||||
(default: 4)
|
||||
-f <n>, --from <n> Convert from this number of spaces
|
||||
(default: auto-detect, will also
|
||||
detect tabs)
|
||||
-n, --tabs Don't convert indentation to spaces,
|
||||
convert to tabs instead. -t and
|
||||
--to will have no effect.
|
||||
-a, --all-tabs Also convert tabs used for alignment
|
||||
in the code (Warning: will replace
|
||||
all tabs in the file, even if inside
|
||||
a string)
|
||||
-s <n>, --tabsize <n> Set how many spaces one tab is
|
||||
(only has an effect on -a, default: 4)
|
||||
-e <s>, --encoding <s> Open files with specified encoding
|
||||
(default: utf-8)
|
||||
""" % args[0]
|
||||
|
||||
# Also removes 8 leading spaces to remove our indentation
|
||||
print("\n".join([x[8:] for x in help[1:].split("\n")]))
|
||||
sys.exit(0)
|
||||
|
||||
if filenames:
|
||||
run_files(filenames, config)
|
||||
else:
|
||||
run(sys.stdin, sys.stdout, config)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv)
|
158
evaluation/evaluation/apps_utils/test_one_solution.py
Normal file
158
evaluation/evaluation/apps_utils/test_one_solution.py
Normal file
@ -0,0 +1,158 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Dan Hendrycks and contributors.
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""
|
||||
Run solutions from one problem.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import apps_utils.testing_util as test_util
|
||||
import time
|
||||
|
||||
# for timing debugging
|
||||
from datetime import datetime, date
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def print_results(results, args):
|
||||
res = []
|
||||
per_prob_res = []
|
||||
all_correct = []
|
||||
for index in results:
|
||||
res.extend(results[index])
|
||||
per_prob_res.append(np.mean(results[index]))
|
||||
all_correct.append(np.all(results[index]))
|
||||
tmp_results = res
|
||||
compile_errors = len(tmp_results[tmp_results == -2])
|
||||
runtime_errors = len(tmp_results[tmp_results == -1])
|
||||
failures = len(tmp_results[tmp_results == False])
|
||||
successes = len(tmp_results[tmp_results == True])
|
||||
total_testcases = len(res)
|
||||
if args.debug:
|
||||
print(
|
||||
f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases }"
|
||||
)
|
||||
print(
|
||||
f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}"
|
||||
)
|
||||
print(f"number of test cases run = {total_testcases}")
|
||||
|
||||
print(
|
||||
f"Test Case Average (average accuracy over problems) = {np.mean(per_prob_res)}"
|
||||
)
|
||||
print(
|
||||
f"Strict Accuracy (all test cases passed / total problems) = {np.mean(all_correct)}"
|
||||
)
|
||||
|
||||
|
||||
def eval_and_save_problems(test_loc, save):
|
||||
test_path = Path(test_loc)
|
||||
problems = list(test_path.glob("*/"))
|
||||
|
||||
print(len(problems))
|
||||
gpt_codes = {}
|
||||
gpt_bleu = {}
|
||||
gpt_codebleu = {}
|
||||
results = {}
|
||||
codes_loc = os.path.join(save, f"all_codes.json")
|
||||
# if not os.path.exists(codes_loc):
|
||||
# codes_loc = os.path.join(args.save, f"{args.start}-{args.end}_codes.json")
|
||||
|
||||
if os.path.exists(codes_loc):
|
||||
results_loc = os.path.join(save, f"all_results.json")
|
||||
print(codes_loc, results_loc)
|
||||
|
||||
with open(codes_loc, "r") as f:
|
||||
gpt_codes = json.load(f)
|
||||
|
||||
# main eval loop
|
||||
for index, problem in enumerate(tqdm(problems[:2])):
|
||||
try:
|
||||
# if args.debug:
|
||||
# print(f"\n\nproblem path = {problem}")
|
||||
output_str = gpt_codes[str(index)]
|
||||
except:
|
||||
print("CANNOT FIND OUTPUT_STR FOR", problem)
|
||||
continue
|
||||
prob_path = problem # os.path.join(args.root, problem)
|
||||
|
||||
# with open(os.path.join(prob_path, "solutions.json"), "r") as f:
|
||||
# sols = json.load(f)
|
||||
|
||||
if not os.path.exists(save):
|
||||
os.makedirs(save)
|
||||
|
||||
res = []
|
||||
# for o_idx, o in enumerate(output_str):
|
||||
# print(o)
|
||||
# if args.debug:
|
||||
# print(f"\nTesting solution {o_idx}")
|
||||
curr_res = [-2]
|
||||
try:
|
||||
curr_res = test_util.run_test(
|
||||
prob_path=prob_path, test=output_str, debug=False # args.debug
|
||||
)
|
||||
fixed = []
|
||||
for e in curr_res:
|
||||
if isinstance(e, np.ndarray):
|
||||
e = e.item(0)
|
||||
if isinstance(e, np.bool_):
|
||||
e = bool(e)
|
||||
fixed.append(e)
|
||||
curr_res = fixed
|
||||
if not np.all(curr_res):
|
||||
print(f"Results were not all True: {curr_res}")
|
||||
except Exception as e:
|
||||
print(f"test framework exception = {repr(e)}{e}\n")
|
||||
break
|
||||
finally:
|
||||
assert isinstance(curr_res, list)
|
||||
res.append(curr_res)
|
||||
|
||||
# if args.debug:
|
||||
# print(
|
||||
# f"\nHow to read results [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case"
|
||||
# )
|
||||
# print(f"results = {res}")
|
||||
|
||||
results[index] = res
|
||||
|
||||
with open(results_loc, "w") as f:
|
||||
try:
|
||||
f.write(json.dumps(results))
|
||||
except Exception as e:
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
print("didn't save problem due to {e}")
|
||||
|
||||
return results
|
544
evaluation/evaluation/apps_utils/testing_util.py
Normal file
544
evaluation/evaluation/apps_utils/testing_util.py
Normal file
@ -0,0 +1,544 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
import faulthandler
|
||||
|
||||
# used for debugging to time steps
|
||||
from datetime import datetime
|
||||
|
||||
# to run the solution files we're using a timing based approach
|
||||
import signal
|
||||
|
||||
import numpy as np
|
||||
# for capturing the stdout
|
||||
from io import StringIO
|
||||
from typing import get_type_hints
|
||||
from typing import List, Tuple
|
||||
# used for testing the code that reads from input
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
from pyext import RuntimeModule
|
||||
|
||||
from enum import Enum
|
||||
class CODE_TYPE(Enum):
|
||||
call_based = 0
|
||||
standard_input = 1
|
||||
|
||||
# stuff for setting up signal timer
|
||||
class TimeoutException(Exception):
|
||||
pass
|
||||
def timeout_handler(signum, frame):
|
||||
print("alarm went off")
|
||||
#return
|
||||
raise TimeoutException
|
||||
signal.signal(signal.SIGALRM, timeout_handler)
|
||||
timeout = 4 # seconds
|
||||
|
||||
# used to capture stdout as a list
|
||||
# from https://stackoverflow.com/a/16571630/6416660
|
||||
# alternative use redirect_stdout() from contextlib
|
||||
class Capturing(list):
|
||||
def __enter__(self):
|
||||
self._stdout = sys.stdout
|
||||
sys.stdout = self._stringio = StringIO()
|
||||
# Make closing the StringIO a no-op
|
||||
self._stringio.close = lambda x: 1
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
self.extend(self._stringio.getvalue().splitlines())
|
||||
del self._stringio # free up some memory
|
||||
sys.stdout = self._stdout
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Utility for testing code generation.")
|
||||
parser.add_argument("-v", "--verbosity-level", action="store", type=int,
|
||||
help="")
|
||||
parser.add_argument("-s", "--source", type=str, default="leetcode",
|
||||
choices=["leetcode", "atcoder", "codewars",],
|
||||
help="which data source to gather from.")
|
||||
parser.add_argument("-d", "--data", type=str, default="question",
|
||||
choices=["question", "q", "solutions", "sol", "s", "starter", "tests", "t"],
|
||||
help="which type of data to receive.")
|
||||
parser.add_argument("-n", "--number", type=int, default=0,
|
||||
help="which problem to query.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def get_valid_problems(data_dir="leetcode"):
|
||||
# these are unnecessary atm
|
||||
if data_dir == "leetcode":
|
||||
root = os.path.join(args.source, "data")
|
||||
elif data_dir == "atcoder":
|
||||
pass
|
||||
|
||||
root = os.path.join(data_dir, "data")
|
||||
if os.path.exists(os.path.join(data_dir, "valid_problems.json")):
|
||||
with open(os.path.join(data_dir, "valid_problems.json"), "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
# after we compute it once let's save it and load that instead
|
||||
# TODO determine if might be better to reload each time
|
||||
tmp = os.listdir(root)
|
||||
valid_probs = []
|
||||
for folder in tmp:
|
||||
prob_path = os.path.join(root, folder)
|
||||
files = os.listdir(prob_path)
|
||||
#TODO add more validity checks
|
||||
if "input_output.json" in files or "sols.json" in files:
|
||||
valid_probs.append(prob_path)
|
||||
valid_probs = sorted(valid_probs)
|
||||
#with open(os.path.join(args.source,"valid_problems.json"), "w") as f:
|
||||
# json.dump(valid_probs, f)
|
||||
return valid_probs
|
||||
|
||||
|
||||
def get_question(problem_list, prob_index):
|
||||
root = problem_list[prob_index]
|
||||
#print("get q", root)
|
||||
if os.path.exists(os.path.join(root, "question.txt")):
|
||||
with open(os.path.join(root, "question.txt")) as f:
|
||||
question = f.readlines()
|
||||
else:
|
||||
print("question prompt not found")
|
||||
question = ""
|
||||
question = "".join(question)
|
||||
return question
|
||||
|
||||
|
||||
def get_solutions(problem_list, prob_index):
|
||||
root = problem_list[prob_index]
|
||||
if os.path.exists(os.path.join(root, "solutions.json")):
|
||||
with open(os.path.join(root, "solutions.json")) as f:
|
||||
sols = json.load(f)
|
||||
return sols
|
||||
|
||||
|
||||
def run_test(prob_path:str=None, problem_list:List[str]=None, prob_index:int=None,
|
||||
test:str=None, debug:bool=False):
|
||||
"""
|
||||
if test is not None it'll try to run the code.
|
||||
otherwise it'll just return an input and output pair.
|
||||
"""
|
||||
if prob_path is None and problem_list is None:
|
||||
print("please provide either prob_path or problem_list")
|
||||
exit()
|
||||
|
||||
if debug:
|
||||
print(f"start = {datetime.now().time()}")
|
||||
if prob_path is not None:
|
||||
root = prob_path
|
||||
elif problem_list is not None:
|
||||
root = problem_list[prob_index]
|
||||
|
||||
if os.path.exists(os.path.join(root, "input_output.json")):
|
||||
with open(os.path.join(root, "input_output.json")) as f:
|
||||
in_outs = json.load(f)
|
||||
if debug:
|
||||
print(f"test cases json = {in_outs['inputs']} {in_outs['outputs']}")
|
||||
|
||||
if in_outs.get("fn_name") is None:
|
||||
which_type = CODE_TYPE.standard_input # Standard input
|
||||
method_name = None
|
||||
else:
|
||||
which_type = CODE_TYPE.call_based # Call-based
|
||||
method_name = in_outs["fn_name"]
|
||||
if debug:
|
||||
print(f"loaded json = {datetime.now().time()}")
|
||||
|
||||
#else:
|
||||
# continue
|
||||
if test is None:
|
||||
return in_outs
|
||||
elif test is not None:
|
||||
results = []
|
||||
sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
|
||||
if debug:
|
||||
print(f"loading test code = {datetime.now().time()}")
|
||||
|
||||
if which_type == CODE_TYPE.call_based:
|
||||
sol += test
|
||||
if debug: # or True:
|
||||
print(f"sol = {sol}")
|
||||
signal.alarm(timeout)
|
||||
try:
|
||||
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
|
||||
if "class Solution" not in test:
|
||||
tmp = tmp_sol
|
||||
else:
|
||||
tmp = tmp_sol.Solution()
|
||||
signal.alarm(0)
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
print(f"type 0 compilation error = {e}")
|
||||
results.append(-2)
|
||||
return results
|
||||
signal.alarm(0)
|
||||
|
||||
elif which_type == CODE_TYPE.standard_input:
|
||||
# sol
|
||||
tmp_test = test.split("\n")
|
||||
|
||||
new_test = []
|
||||
for x in tmp_test:
|
||||
if (not x.startswith("from ")) and (not x.startswith("import ")):
|
||||
new_test.append("\t" + x + "\n")
|
||||
else:
|
||||
new_test.append(x + "\n")
|
||||
tmp_test = new_test
|
||||
|
||||
new_test = ""
|
||||
started = False
|
||||
for i in tmp_test:
|
||||
if i.startswith("\t") and not started:
|
||||
new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
|
||||
new_test += "def code():\n"
|
||||
new_test += i
|
||||
started = True
|
||||
elif started and ((i.startswith("from ")) or (i.startswith("import "))):
|
||||
new_test += "\t" + i
|
||||
else:
|
||||
new_test += i
|
||||
tmp_test = new_test
|
||||
|
||||
sol += tmp_test
|
||||
if debug:
|
||||
print(f"sol = {sol}")
|
||||
# print(f"{o}")
|
||||
method_name = "code"
|
||||
signal.alarm(timeout)
|
||||
try:
|
||||
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
|
||||
tmp = tmp_sol
|
||||
signal.alarm(0)
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
print(f"type 1 compilation error = {e}")
|
||||
results.append(-2)
|
||||
return results
|
||||
signal.alarm(0)
|
||||
if debug:
|
||||
print(f"get method = {datetime.now().time()}")
|
||||
|
||||
try:
|
||||
method = getattr(tmp, method_name) # get_attr second arg must be str
|
||||
except:
|
||||
signal.alarm(0)
|
||||
e = sys.exc_info()
|
||||
print(f"unable to get function error = {e}")
|
||||
return results
|
||||
|
||||
for index, inputs in enumerate(in_outs["inputs"]):
|
||||
# JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
|
||||
try:
|
||||
if isinstance(inputs[0], dict):
|
||||
inputs = [{int(k): v for k,v in inputs[0].items()}]
|
||||
except:
|
||||
True
|
||||
try:
|
||||
if isinstance(in_outs["outputs"][index], dict):
|
||||
in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
|
||||
except:
|
||||
True
|
||||
try:
|
||||
if isinstance(in_outs["outputs"][index][0], dict):
|
||||
in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
|
||||
except:
|
||||
True
|
||||
|
||||
if debug:
|
||||
print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
|
||||
if which_type == CODE_TYPE.call_based: # Call-based
|
||||
signal.alarm(timeout)
|
||||
faulthandler.enable()
|
||||
try:
|
||||
# print("------------")
|
||||
# print(inputs)
|
||||
output = method(*inputs)
|
||||
|
||||
# ground truth sequences are not tuples
|
||||
if isinstance(output, tuple):
|
||||
output = list(output)
|
||||
|
||||
tmp_result = output == in_outs["outputs"][index]
|
||||
if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
|
||||
tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
|
||||
|
||||
# ground truth sequences are not tuples
|
||||
try:
|
||||
if isinstance(output[0], tuple):
|
||||
tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
|
||||
except:
|
||||
True
|
||||
results.append(tmp_result)
|
||||
|
||||
# reset the alarm
|
||||
signal.alarm(0)
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
faulthandler.disable()
|
||||
print(f"Standard input runtime error or time limit exceeded error = {e}")
|
||||
results.append(-1)
|
||||
continue
|
||||
faulthandler.disable()
|
||||
signal.alarm(0)
|
||||
if debug:
|
||||
print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
elif which_type == CODE_TYPE.standard_input: # Standard input
|
||||
faulthandler.enable()
|
||||
signal.alarm(timeout)
|
||||
passed = False
|
||||
|
||||
if isinstance(inputs, list):
|
||||
inputs = "\n".join(inputs)
|
||||
if isinstance(in_outs['outputs'][index], list):
|
||||
in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
|
||||
|
||||
with Capturing() as output:
|
||||
try:
|
||||
call_method(method, inputs)
|
||||
# reset the alarm
|
||||
signal.alarm(0)
|
||||
passed = True
|
||||
except Exception as e:
|
||||
# runtime error or took too long
|
||||
signal.alarm(0)
|
||||
print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
|
||||
results.append(-1)
|
||||
signal.alarm(0)
|
||||
|
||||
if not passed:
|
||||
if debug:
|
||||
nl = "\n"
|
||||
if not isinstance(inputs, list):
|
||||
print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
else:
|
||||
print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
continue
|
||||
|
||||
if passed and debug:
|
||||
print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
|
||||
|
||||
if custom_compare_(output, in_outs['outputs'][index]):
|
||||
tmp_result = True
|
||||
results.append(tmp_result)
|
||||
continue
|
||||
|
||||
# ground truth sequences are expressed as lists not tuples
|
||||
if isinstance(output, tuple):
|
||||
output = list(output)
|
||||
|
||||
tmp_result = False
|
||||
try:
|
||||
tmp_result = (output == [in_outs["outputs"][index]])
|
||||
if isinstance(in_outs["outputs"][index], list):
|
||||
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
||||
if isinstance(output[0], str):
|
||||
tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
|
||||
except Exception as e:
|
||||
print(f"Failed check1 exception = {e}")
|
||||
pass
|
||||
|
||||
if tmp_result == True:
|
||||
results.append(tmp_result)
|
||||
continue
|
||||
|
||||
# try one more time without \n
|
||||
if isinstance(in_outs["outputs"][index], list):
|
||||
for tmp_index, i in enumerate(in_outs["outputs"][index]):
|
||||
in_outs["outputs"][index][tmp_index] = i.split("\n")
|
||||
in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
|
||||
else:
|
||||
in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
|
||||
in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
|
||||
in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
|
||||
|
||||
try:
|
||||
tmp_result = (output == [in_outs["outputs"][index]])
|
||||
if isinstance(in_outs["outputs"][index], list):
|
||||
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
||||
except Exception as e:
|
||||
print(f"Failed check2 exception = {e}")
|
||||
pass
|
||||
|
||||
if tmp_result == True:
|
||||
results.append(tmp_result)
|
||||
continue
|
||||
|
||||
# try by converting the output into a split up list too
|
||||
if isinstance(output, list):
|
||||
output = list(filter(len, output))
|
||||
|
||||
if debug:
|
||||
nl = "\n"
|
||||
if not isinstance(inputs, list):
|
||||
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
else:
|
||||
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
|
||||
if tmp_result == True:
|
||||
results.append(tmp_result)
|
||||
continue
|
||||
|
||||
try:
|
||||
tmp_result = (output == [in_outs["outputs"][index]])
|
||||
if isinstance(in_outs["outputs"][index], list):
|
||||
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
||||
except Exception as e:
|
||||
print(f"Failed check3 exception = {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
output_float = [float(e) for e in output]
|
||||
gt_float = [float(e) for e in in_outs['outputs'][index]]
|
||||
tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
if isinstance(output[0], list):
|
||||
output_float = [float(e) for e in output[0]]
|
||||
gt_float = [float(e) for e in in_outs['outputs'][index][0]]
|
||||
tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if tmp_result == True:
|
||||
results.append(tmp_result)
|
||||
continue
|
||||
|
||||
# try by converting the stuff into split up list
|
||||
if isinstance(in_outs["outputs"][index], list):
|
||||
for tmp_index, i in enumerate(in_outs["outputs"][index]):
|
||||
in_outs["outputs"][index][tmp_index] = set(i.split())
|
||||
else:
|
||||
in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
|
||||
|
||||
try:
|
||||
tmp_result = (output == in_outs["outputs"][index])
|
||||
except Exception as e:
|
||||
print(f"Failed check4 exception = {e}")
|
||||
continue
|
||||
|
||||
if tmp_result == True:
|
||||
results.append(tmp_result)
|
||||
continue
|
||||
|
||||
# try by converting the output into a split up list too
|
||||
if isinstance(output, list):
|
||||
for tmp_index, i in enumerate(output):
|
||||
output[tmp_index] = i.split()
|
||||
output = list(filter(len, output))
|
||||
for tmp_index, i in enumerate(output):
|
||||
output[tmp_index] = set(i)
|
||||
else:
|
||||
output = output.split()
|
||||
output = list(filter(len, output))
|
||||
output = set(output)
|
||||
|
||||
try:
|
||||
tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
|
||||
except Exception as e:
|
||||
print(f"Failed check5 exception = {e}")
|
||||
|
||||
|
||||
# if they are all numbers, round so that similar numbers are treated as identical
|
||||
try:
|
||||
tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
|
||||
set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
|
||||
except Exception as e:
|
||||
print(f"Failed check6 exception = {e}")
|
||||
|
||||
if tmp_result == True and debug:
|
||||
print("PASSED")
|
||||
|
||||
results.append(tmp_result)
|
||||
|
||||
if debug:
|
||||
nl = "\n"
|
||||
if not isinstance(inputs, list):
|
||||
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
else:
|
||||
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def custom_compare_(output, ground_truth):
|
||||
|
||||
if isinstance(output, list):
|
||||
output_1 = "\n".join(output)
|
||||
if stripped_string_compare(output_1, ground_truth):
|
||||
return True
|
||||
|
||||
if isinstance(output, list):
|
||||
output_2 = [o.lstrip().rstrip() for o in output]
|
||||
output_2 = "\n".join(output_2)
|
||||
if stripped_string_compare(output_2, ground_truth):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def stripped_string_compare(s1, s2):
|
||||
s1 = s1.lstrip().rstrip()
|
||||
s2 = s2.lstrip().rstrip()
|
||||
return s1 == s2
|
||||
|
||||
def call_method(method, inputs):
|
||||
|
||||
if isinstance(inputs, list):
|
||||
inputs = "\n".join(inputs)
|
||||
|
||||
inputs_line_iterator = iter(inputs.split("\n"))
|
||||
|
||||
# sys.setrecursionlimit(10000)
|
||||
|
||||
# @patch('builtins.input', side_effect=inputs.split("\n"))
|
||||
@patch('builtins.open', mock_open(read_data=inputs))
|
||||
@patch('sys.stdin', StringIO(inputs))
|
||||
@patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
|
||||
@patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
|
||||
@patch('sys.stdin.read', lambda *args: inputs)
|
||||
# @patch('sys.stdout.write', print)
|
||||
def _inner_call_method(_method):
|
||||
try:
|
||||
return _method()
|
||||
except SystemExit as e:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
return _inner_call_method(method)
|
||||
|
||||
def main(args):
|
||||
print(args)
|
||||
problem_list = sorted(get_valid_problems(args.source))
|
||||
print(f"number of problems = {len(problem_list)}")
|
||||
prob_index = args.number
|
||||
print(f"problem is {problem_list[prob_index]}")
|
||||
|
||||
# This checks it correctly loaded. remove this later
|
||||
assert prob_index < len(problem_list)
|
||||
|
||||
if args.data == "q" or args.data == "question":
|
||||
tmp = get_question(problem_list, prob_index)
|
||||
print("q", tmp)
|
||||
elif args.data in ["solutions", "sol", "s",]:
|
||||
tmp = get_solutions(problem_list, prob_index)
|
||||
print("sol", tmp)
|
||||
elif args.data == "starter":
|
||||
tmp = get_starter(problem_list, prob_index)
|
||||
print("starter", tmp)
|
||||
elif args.data in ["test", "t"]:
|
||||
# test it with sols
|
||||
sols = get_solutions(problem_list, prob_index)
|
||||
tmp = run_test(problem_list, prob_index, test=sols[0])
|
||||
|
||||
print("results = ", tmp)
|
||||
print("-2 = compile error, -1 is runtime error, False failed test, True passed test")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
4
evaluation/evaluation/code_search_net.py
Normal file
4
evaluation/evaluation/code_search_net.py
Normal file
@ -0,0 +1,4 @@
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("code_x_glue_ct_code_to_text", "go")
|
||||
print(dataset)
|
23
evaluation/evaluation/concode.py
Normal file
23
evaluation/evaluation/concode.py
Normal file
@ -0,0 +1,23 @@
|
||||
import pandas as pd
|
||||
|
||||
from datasets import load_dataset, load_metric
|
||||
from fastcore.script import *
|
||||
from pathlib import Path
|
||||
|
||||
bleu = load_metric("sacrebleu")
|
||||
|
||||
predictions = ["hello there kenobi", "foo bar foobar"]
|
||||
references = [
|
||||
["hello there general kenobi"],
|
||||
["foo bar foobar"], # , "hello there !"], # , "foo bar foobar"],
|
||||
]
|
||||
|
||||
|
||||
@call_parse
|
||||
def main(concode_path: Param("Path to the concode data in CodeXGLUE", str)):
|
||||
concode_path = Path(concode_path)
|
||||
dataset = load_dataset("json", data_files=str(concode_path / "test.json"))
|
||||
print(dataset)
|
||||
results = bleu.compute(predictions=predictions, references=references)
|
||||
print(list(results.keys()))
|
||||
print(round(results["score"], 1))
|
164
evaluation/evaluation/human_eval.jsonl
Normal file
164
evaluation/evaluation/human_eval.jsonl
Normal file
@ -0,0 +1,164 @@
|
||||
{"task_id": "HumanEval/0", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/1", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/2", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/3", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/4", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/5", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/6", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/7", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/8", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/9", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/10", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/11", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/12", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/13", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/14", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/15", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/16", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/17", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/18", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/19", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/20", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/21", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/22", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/23", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/24", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/25", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/26", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/27", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/28", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/29", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/30", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/31", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/32", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/33", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/34", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/35", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/36", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/37", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/38", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/39", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/40", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/41", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/42", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/43", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/44", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/45", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/46", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/47", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/48", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/49", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/50", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/51", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/52", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/53", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/54", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/55", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/56", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/57", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/58", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/59", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/60", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/61", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/62", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/63", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/64", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/65", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/66", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/67", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/68", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/69", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/70", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/71", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/72", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/73", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/74", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/75", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/76", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/77", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/78", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/79", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/80", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/81", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/82", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/83", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/84", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/85", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/86", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/87", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/88", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/89", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/90", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/91", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/92", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/93", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/94", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/95", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/96", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/97", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/98", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/99", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/100", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/101", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/102", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/103", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/104", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/105", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/106", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/107", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/108", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/109", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/110", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/111", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/112", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/113", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/114", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/115", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/116", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/117", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/118", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/119", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/120", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/121", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/122", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/123", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/124", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/125", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/126", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/127", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/128", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/129", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/130", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/131", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/132", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/133", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/134", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/135", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/136", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/137", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/138", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/139", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/140", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/141", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/142", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/143", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/144", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/145", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/146", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/147", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/148", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/149", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/150", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/151", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/152", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/153", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/154", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/155", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/156", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/157", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/158", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/159", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/160", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/161", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/162", "completion": "a = b \n y = a + 1"}
|
||||
{"task_id": "HumanEval/163", "completion": "a = b \n y = a + 1"}
|
164
evaluation/evaluation/human_eval.jsonl_results.jsonl
Normal file
164
evaluation/evaluation/human_eval.jsonl_results.jsonl
Normal file
@ -0,0 +1,164 @@
|
||||
{"task_id": "HumanEval/0", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/1", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/2", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/3", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/4", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/5", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/6", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/7", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/8", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/9", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/10", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/11", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/12", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/13", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/14", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 10)", "passed": false}
|
||||
{"task_id": "HumanEval/15", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/16", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/17", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/18", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/19", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/20", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/21", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/22", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/23", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/24", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 9)", "passed": false}
|
||||
{"task_id": "HumanEval/25", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/26", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/27", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 9)", "passed": false}
|
||||
{"task_id": "HumanEval/28", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/29", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/30", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/31", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/32", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 25)", "passed": false}
|
||||
{"task_id": "HumanEval/33", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/34", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 9)", "passed": false}
|
||||
{"task_id": "HumanEval/35", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/36", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/37", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/38", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 19)", "passed": false}
|
||||
{"task_id": "HumanEval/39", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/40", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/41", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/42", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/43", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 20)", "passed": false}
|
||||
{"task_id": "HumanEval/44", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/45", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 9)", "passed": false}
|
||||
{"task_id": "HumanEval/46", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 19)", "passed": false}
|
||||
{"task_id": "HumanEval/47", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/48", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/49", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/50", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/51", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 20)", "passed": false}
|
||||
{"task_id": "HumanEval/52", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/53", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/54", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 20)", "passed": false}
|
||||
{"task_id": "HumanEval/55", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/56", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/57", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/58", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/59", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/60", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/61", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/62", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/63", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/64", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 19)", "passed": false}
|
||||
{"task_id": "HumanEval/65", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/66", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/67", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/68", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 37)", "passed": false}
|
||||
{"task_id": "HumanEval/69", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/70", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/71", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/72", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/73", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/74", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/75", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/76", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/77", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/78", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/79", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/80", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/81", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 28)", "passed": false}
|
||||
{"task_id": "HumanEval/82", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/83", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 8)", "passed": false}
|
||||
{"task_id": "HumanEval/84", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/85", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 10)", "passed": false}
|
||||
{"task_id": "HumanEval/86", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/87", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 23)", "passed": false}
|
||||
{"task_id": "HumanEval/88", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/89", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/90", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/91", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/92", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/93", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/94", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/95", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/96", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/97", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/98", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/99", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 21)", "passed": false}
|
||||
{"task_id": "HumanEval/100", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/101", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/102", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/103", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/104", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/105", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 25)", "passed": false}
|
||||
{"task_id": "HumanEval/106", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/107", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 26)", "passed": false}
|
||||
{"task_id": "HumanEval/108", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/109", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 30)", "passed": false}
|
||||
{"task_id": "HumanEval/110", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/111", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/112", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/113", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/114", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/115", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 38)", "passed": false}
|
||||
{"task_id": "HumanEval/116", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/117", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/118", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/119", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/120", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 28)", "passed": false}
|
||||
{"task_id": "HumanEval/121", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/122", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/123", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 20)", "passed": false}
|
||||
{"task_id": "HumanEval/124", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 23)", "passed": false}
|
||||
{"task_id": "HumanEval/125", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/126", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 19)", "passed": false}
|
||||
{"task_id": "HumanEval/127", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 23)", "passed": false}
|
||||
{"task_id": "HumanEval/128", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/129", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 33)", "passed": false}
|
||||
{"task_id": "HumanEval/130", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 20)", "passed": false}
|
||||
{"task_id": "HumanEval/131", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/132", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/133", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/134", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/135", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/136", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/137", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/138", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 10)", "passed": false}
|
||||
{"task_id": "HumanEval/139", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/140", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/141", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/142", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 16)", "passed": false}
|
||||
{"task_id": "HumanEval/143", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 23)", "passed": false}
|
||||
{"task_id": "HumanEval/144", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/145", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/146", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 11)", "passed": false}
|
||||
{"task_id": "HumanEval/147", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 17)", "passed": false}
|
||||
{"task_id": "HumanEval/148", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 19)", "passed": false}
|
||||
{"task_id": "HumanEval/149", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/150", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 12)", "passed": false}
|
||||
{"task_id": "HumanEval/151", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 15)", "passed": false}
|
||||
{"task_id": "HumanEval/152", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 18)", "passed": false}
|
||||
{"task_id": "HumanEval/153", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 20)", "passed": false}
|
||||
{"task_id": "HumanEval/154", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/155", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 10)", "passed": false}
|
||||
{"task_id": "HumanEval/156", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/157", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/158", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
||||
{"task_id": "HumanEval/159", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 32)", "passed": false}
|
||||
{"task_id": "HumanEval/160", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 28)", "passed": false}
|
||||
{"task_id": "HumanEval/161", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 14)", "passed": false}
|
||||
{"task_id": "HumanEval/162", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 10)", "passed": false}
|
||||
{"task_id": "HumanEval/163", "completion": "a = b \n y = a + 1", "result": "failed: unexpected indent (<string>, line 13)", "passed": false}
|
0
evaluation/evaluation/human_eval_bench.py
Normal file
0
evaluation/evaluation/human_eval_bench.py
Normal file
133
evaluation/evaluation/metrics/bleu.py
Normal file
133
evaluation/evaluation/metrics/bleu.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# The following code is taken from CodeXGlue Repository - https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/evaluator/CodeBLEU/bleu.py
|
||||
|
||||
|
||||
"""Python implementation of BLEU and smooth-BLEU.
|
||||
|
||||
This module provides a Python implementation of BLEU and smooth-BLEU.
|
||||
Smooth BLEU is computed following the method outlined in the paper:
|
||||
Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
|
||||
evaluation metrics for machine translation. COLING 2004.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
|
||||
def _get_ngrams(segment, max_order):
|
||||
"""Extracts all n-grams upto a given maximum order from an input segment.
|
||||
|
||||
Args:
|
||||
segment: text segment from which n-grams will be extracted.
|
||||
max_order: maximum length in tokens of the n-grams returned by this
|
||||
methods.
|
||||
|
||||
Returns:
|
||||
The Counter containing all n-grams upto max_order in segment
|
||||
with a count of how many times each n-gram occurred.
|
||||
"""
|
||||
ngram_counts = collections.Counter()
|
||||
for order in range(1, max_order + 1):
|
||||
for i in range(0, len(segment) - order + 1):
|
||||
ngram = tuple(segment[i : i + order])
|
||||
ngram_counts[ngram] += 1
|
||||
return ngram_counts
|
||||
|
||||
|
||||
def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=True):
|
||||
"""Computes BLEU score of translated segments against one or more references.
|
||||
|
||||
Args:
|
||||
reference_corpus: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation_corpus: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
|
||||
Returns:
|
||||
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
|
||||
precisions and brevity penalty.
|
||||
"""
|
||||
matches_by_order = [0] * max_order
|
||||
possible_matches_by_order = [0] * max_order
|
||||
reference_length = 0
|
||||
translation_length = 0
|
||||
for (references, translation) in zip(reference_corpus, translation_corpus):
|
||||
reference_length += min(len(r) for r in references)
|
||||
translation_length += len(translation)
|
||||
|
||||
merged_ref_ngram_counts = collections.Counter()
|
||||
for reference in references:
|
||||
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
|
||||
translation_ngram_counts = _get_ngrams(translation, max_order)
|
||||
overlap = translation_ngram_counts & merged_ref_ngram_counts
|
||||
for ngram in overlap:
|
||||
matches_by_order[len(ngram) - 1] += overlap[ngram]
|
||||
for order in range(1, max_order + 1):
|
||||
possible_matches = len(translation) - order + 1
|
||||
if possible_matches > 0:
|
||||
possible_matches_by_order[order - 1] += possible_matches
|
||||
|
||||
precisions = [0] * max_order
|
||||
for i in range(0, max_order):
|
||||
if smooth:
|
||||
precisions[i] = (matches_by_order[i] + 1.0) / (
|
||||
possible_matches_by_order[i] + 1.0
|
||||
)
|
||||
else:
|
||||
if possible_matches_by_order[i] > 0:
|
||||
precisions[i] = (
|
||||
float(matches_by_order[i]) / possible_matches_by_order[i]
|
||||
)
|
||||
else:
|
||||
precisions[i] = 0.0
|
||||
|
||||
if min(precisions) > 0:
|
||||
p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
|
||||
geo_mean = math.exp(p_log_sum)
|
||||
else:
|
||||
geo_mean = 0
|
||||
|
||||
ratio = float(translation_length) / reference_length
|
||||
|
||||
if ratio > 1.0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - 1.0 / ratio)
|
||||
bleu = geo_mean * bp
|
||||
bleu_score_dict = {
|
||||
"bleu": bleu,
|
||||
"precision": precisions,
|
||||
"bp": bp,
|
||||
"ratio": ratio,
|
||||
"trans_len": translation_length,
|
||||
"ref_len": reference_length,
|
||||
}
|
||||
return bleu_score_dict # (bleu, precisions, bp, ratio, translation_length, reference_length)
|
||||
|
||||
|
||||
def bleu_test_case():
|
||||
"""A simple functionality test case to evaluate BLEU"""
|
||||
generated = [[["a", "=", "b", "\n", "y", "=", "a", "+", "1"]]]
|
||||
reference = [["a", "=", "b", "\n", "print", "a"]]
|
||||
score_dict = compute_bleu(generated, reference, smooth=False)
|
||||
return score_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
score_dict = bleu_test_case()
|
||||
print(score_dict)
|
46
evaluation/evaluation/metrics/extrinsic_eval.py
Normal file
46
evaluation/evaluation/metrics/extrinsic_eval.py
Normal file
@ -0,0 +1,46 @@
|
||||
from metrics.bleu import compute_bleu
|
||||
|
||||
|
||||
def compute_exact_match(references, generated) -> float:
|
||||
"""
|
||||
Computes Exact Match Accuracy.
|
||||
args:
|
||||
reference: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
returns:
|
||||
exact_match_accuracy : Float
|
||||
"""
|
||||
assert (
|
||||
len(references[0]) == len(generated),
|
||||
"Number of Samples should be equal in References and Synthesized Outputs..",
|
||||
)
|
||||
exact_match_count = 0.0
|
||||
for gen, ref in zip(generated, references[0]):
|
||||
if gen == ref:
|
||||
exact_match_count += 1
|
||||
exact_match_acc = exact_match_count / len(generated)
|
||||
return exact_match_acc
|
||||
|
||||
|
||||
def compute_metrics(references, generated) -> dict:
|
||||
"""
|
||||
Calculates various metrics and returns the calculated dict of these matrics.
|
||||
args:
|
||||
reference: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
returns:
|
||||
A dicitonary with different metrics intact.
|
||||
"""
|
||||
metrics_dict = {
|
||||
"smoothed_bleu_4": None,
|
||||
"bleu_4": None,
|
||||
"exact_match_acc": None,
|
||||
} # Update as in new metrics are computed.
|
||||
metrics_dict["smoothed_bleu_4"] = compute_bleu(references, generated, smooth=True)
|
||||
metrics_dict["bleu_4"] = compute_bleu(references, generated, smooth=False)
|
||||
metrics_dict["exact_match_acc"] = compute_exact_match(references, generated)
|
||||
return metrics_dict
|
BIN
evaluation/metrics/.DS_Store
vendored
Normal file
BIN
evaluation/metrics/.DS_Store
vendored
Normal file
Binary file not shown.
133
evaluation/metrics/bleu.py
Normal file
133
evaluation/metrics/bleu.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# The following code is taken from CodeXGlue Repository - https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/evaluator/CodeBLEU/bleu.py
|
||||
|
||||
|
||||
"""Python implementation of BLEU and smooth-BLEU.
|
||||
|
||||
This module provides a Python implementation of BLEU and smooth-BLEU.
|
||||
Smooth BLEU is computed following the method outlined in the paper:
|
||||
Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
|
||||
evaluation metrics for machine translation. COLING 2004.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
|
||||
def _get_ngrams(segment, max_order):
|
||||
"""Extracts all n-grams upto a given maximum order from an input segment.
|
||||
|
||||
Args:
|
||||
segment: text segment from which n-grams will be extracted.
|
||||
max_order: maximum length in tokens of the n-grams returned by this
|
||||
methods.
|
||||
|
||||
Returns:
|
||||
The Counter containing all n-grams upto max_order in segment
|
||||
with a count of how many times each n-gram occurred.
|
||||
"""
|
||||
ngram_counts = collections.Counter()
|
||||
for order in range(1, max_order + 1):
|
||||
for i in range(0, len(segment) - order + 1):
|
||||
ngram = tuple(segment[i : i + order])
|
||||
ngram_counts[ngram] += 1
|
||||
return ngram_counts
|
||||
|
||||
|
||||
def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=True):
|
||||
"""Computes BLEU score of translated segments against one or more references.
|
||||
|
||||
Args:
|
||||
reference_corpus: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation_corpus: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
|
||||
Returns:
|
||||
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
|
||||
precisions and brevity penalty.
|
||||
"""
|
||||
matches_by_order = [0] * max_order
|
||||
possible_matches_by_order = [0] * max_order
|
||||
reference_length = 0
|
||||
translation_length = 0
|
||||
for (references, translation) in zip(reference_corpus, translation_corpus):
|
||||
reference_length += min(len(r) for r in references)
|
||||
translation_length += len(translation)
|
||||
|
||||
merged_ref_ngram_counts = collections.Counter()
|
||||
for reference in references:
|
||||
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
|
||||
translation_ngram_counts = _get_ngrams(translation, max_order)
|
||||
overlap = translation_ngram_counts & merged_ref_ngram_counts
|
||||
for ngram in overlap:
|
||||
matches_by_order[len(ngram) - 1] += overlap[ngram]
|
||||
for order in range(1, max_order + 1):
|
||||
possible_matches = len(translation) - order + 1
|
||||
if possible_matches > 0:
|
||||
possible_matches_by_order[order - 1] += possible_matches
|
||||
|
||||
precisions = [0] * max_order
|
||||
for i in range(0, max_order):
|
||||
if smooth:
|
||||
precisions[i] = (matches_by_order[i] + 1.0) / (
|
||||
possible_matches_by_order[i] + 1.0
|
||||
)
|
||||
else:
|
||||
if possible_matches_by_order[i] > 0:
|
||||
precisions[i] = (
|
||||
float(matches_by_order[i]) / possible_matches_by_order[i]
|
||||
)
|
||||
else:
|
||||
precisions[i] = 0.0
|
||||
|
||||
if min(precisions) > 0:
|
||||
p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
|
||||
geo_mean = math.exp(p_log_sum)
|
||||
else:
|
||||
geo_mean = 0
|
||||
|
||||
ratio = float(translation_length) / reference_length
|
||||
|
||||
if ratio > 1.0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - 1.0 / ratio)
|
||||
bleu = geo_mean * bp
|
||||
bleu_score_dict = {
|
||||
"bleu": bleu,
|
||||
"precision": precisions,
|
||||
"bp": bp,
|
||||
"ratio": ratio,
|
||||
"trans_len": translation_length,
|
||||
"ref_len": reference_length,
|
||||
}
|
||||
return bleu_score_dict # (bleu, precisions, bp, ratio, translation_length, reference_length)
|
||||
|
||||
|
||||
def bleu_test_case():
|
||||
"""A simple functionality test case to evaluate BLEU"""
|
||||
generated = [[["a", "=", "b", "\n", "y", "=", "a", "+", "1"]]]
|
||||
reference = [["a", "=", "b", "\n", "print", "a"]]
|
||||
score_dict = compute_bleu(generated, reference, smooth=False)
|
||||
return score_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
score_dict = bleu_test_case()
|
||||
print(score_dict)
|
23
evaluation/metrics/extrinsic_eval.py
Normal file
23
evaluation/metrics/extrinsic_eval.py
Normal file
@ -0,0 +1,23 @@
|
||||
from metrics.bleu import compute_bleu
|
||||
from metrics.parse_check import check_parse
|
||||
|
||||
Parser = check_parse() # Initializing parser
|
||||
|
||||
|
||||
def compute_metrics(references, generated, lang) -> dict:
|
||||
"""
|
||||
Calculates various metrics and returns the calculated dict of these matrics.
|
||||
args:
|
||||
reference: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
lang(str) : The language generated code belongs to
|
||||
returns:
|
||||
A dicitonary with different metrics intact.
|
||||
"""
|
||||
metrics_dict = {} # Update as in new metrics are added over here.
|
||||
metrics_dict["smoothed_bleu_4"] = compute_bleu(references, generated, smooth=True)
|
||||
metrics_dict["bleu_4"] = compute_bleu(references, generated, smooth=False)
|
||||
metrics_dict["parse_score"] = Parser(generated, lang)["parse_score"]
|
||||
return metrics_dict
|
53
evaluation/metrics/parse_check.py
Normal file
53
evaluation/metrics/parse_check.py
Normal file
@ -0,0 +1,53 @@
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
def load_tree_sitter_languages():
|
||||
"""Loads language Grammars to evaluate"""
|
||||
py_parser = Parser()
|
||||
py_parser.set_language(Language('./tree_sitter_utils/build/my-languages.so', 'python'))
|
||||
js_parser = Parser()
|
||||
js_parser.set_language(Language('./tree_sitter_utils/build/my-languages.so', 'javascript'))
|
||||
cpp_parser = Parser()
|
||||
cpp_parser.set_language(Language('./tree_sitter_utils/build/my-languages.so', 'cpp'))
|
||||
go_parser = Parser()
|
||||
go_parser.set_language(Language('./tree_sitter_utils/build/my-languages.so', 'go'))
|
||||
java_parser = Parser()
|
||||
java_parser.set_language(Language('./tree_sitter_utils/build/my-languages.so', 'java'))
|
||||
return {
|
||||
"py" : py_parser,
|
||||
"js" : js_parser,
|
||||
"cpp" : cpp_parser,
|
||||
"go" : go_parser,
|
||||
"java": java_parser
|
||||
}
|
||||
|
||||
class check_parse:
|
||||
def __init__(self):
|
||||
self.language_dict = load_tree_sitter_languages()
|
||||
def __call__(self,batch,lang):
|
||||
"""
|
||||
args:
|
||||
batch : list[str] of code generated by the model
|
||||
lang : lang should be one of the above language_dict keys
|
||||
|
||||
returns:
|
||||
dict(
|
||||
parse_score = averaged out score on how many datapoints are parsed
|
||||
index_parse = check if corresponding index is parsed
|
||||
)
|
||||
"""
|
||||
cumulative_parse_score = 0
|
||||
index_parse_list = []
|
||||
parser = self.language_dict[lang]
|
||||
for inp in batch:
|
||||
parsed = parser.parse(bytes(inp,"utf-8"))
|
||||
inp_ind_score = int("ERROR" not in parsed.root_node.sexp())
|
||||
cumulative_parse_score+=inp_ind_score
|
||||
index_parse_list.append(inp_ind_score)
|
||||
return {"parse_score":cumulative_parse_score,"index_parse":index_parse_list}
|
||||
if __name__ == "__main__":
|
||||
Parse = check_parse()
|
||||
score = Parse(["""
|
||||
def a():
|
||||
if bar:
|
||||
baz()"""],"py")
|
||||
print(score)
|
BIN
evaluation/metrics/tree_sitter_utils/.DS_Store
vendored
Normal file
BIN
evaluation/metrics/tree_sitter_utils/.DS_Store
vendored
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@ -1,837 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6d0d0fdf-def5-4c93-b1e9-4223d70c3c22",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!/usr/bin/env python\n",
|
||||
"# coding=utf-8\n",
|
||||
"# Copyright 2021 The HuggingFace Team All rights reserved.\n",
|
||||
"#\n",
|
||||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# http://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License.\n",
|
||||
"\"\"\"\n",
|
||||
"Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.\n",
|
||||
"\n",
|
||||
"Here is the full list of checkpoints on the hub that can be fine-tuned by this script:\n",
|
||||
"https://huggingface.co/models?filter=causal-lm\n",
|
||||
"\"\"\"\n",
|
||||
"# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"import math\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import time\n",
|
||||
"from dataclasses import dataclass, field\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Callable, Optional\n",
|
||||
"\n",
|
||||
"import datasets\n",
|
||||
"from datasets import Dataset, load_dataset\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"\n",
|
||||
"import json\n",
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"import optax\n",
|
||||
"import transformers\n",
|
||||
"from flax import jax_utils, traverse_util\n",
|
||||
"from flax.jax_utils import unreplicate\n",
|
||||
"from flax.training import train_state\n",
|
||||
"from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key\n",
|
||||
"from flax.serialization import to_bytes, from_bytes\n",
|
||||
"from transformers import (\n",
|
||||
" CONFIG_MAPPING,\n",
|
||||
" FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n",
|
||||
" AutoConfig,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" FlaxAutoModelForCausalLM,\n",
|
||||
" HfArgumentParser,\n",
|
||||
" TrainingArguments,\n",
|
||||
" is_tensorboard_available,\n",
|
||||
")\n",
|
||||
"from transformers.testing_utils import CaptureLogger\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"logger = logging.getLogger(__name__)\n",
|
||||
"\n",
|
||||
"MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())\n",
|
||||
"MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class ModelArguments:\n",
|
||||
" \"\"\"\n",
|
||||
" Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" model_name_or_path: Optional[str] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\n",
|
||||
" \"help\": \"The model checkpoint for weights initialization.\"\n",
|
||||
" \"Don't set if you want to train a model from scratch.\"\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" model_type: Optional[str] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n",
|
||||
" )\n",
|
||||
" config_name: Optional[str] = field(\n",
|
||||
" default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n",
|
||||
" )\n",
|
||||
" tokenizer_name: Optional[str] = field(\n",
|
||||
" default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n",
|
||||
" )\n",
|
||||
" cache_dir: Optional[str] = field(\n",
|
||||
" default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n",
|
||||
" )\n",
|
||||
" use_fast_tokenizer: bool = field(\n",
|
||||
" default=True,\n",
|
||||
" metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n",
|
||||
" )\n",
|
||||
" dtype: Optional[str] = field(\n",
|
||||
" default=\"float32\",\n",
|
||||
" metadata={\n",
|
||||
" \"help\": \"Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`.\"\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class DataTrainingArguments:\n",
|
||||
" \"\"\"\n",
|
||||
" Arguments pertaining to what data we are going to input our model for training and eval.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" dataset_name: Optional[str] = field(\n",
|
||||
" default=None, metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n",
|
||||
" )\n",
|
||||
" dataset_config_name: Optional[str] = field(\n",
|
||||
" default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n",
|
||||
" )\n",
|
||||
" train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n",
|
||||
" validation_file: Optional[str] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n",
|
||||
" )\n",
|
||||
" max_train_samples: Optional[int] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\n",
|
||||
" \"help\": \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n",
|
||||
" \"value if set.\"\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" max_eval_samples: Optional[int] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\n",
|
||||
" \"help\": \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n",
|
||||
" \"value if set.\"\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" overwrite_cache: bool = field(\n",
|
||||
" default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n",
|
||||
" )\n",
|
||||
" validation_split_percentage: Optional[int] = field(\n",
|
||||
" default=5,\n",
|
||||
" metadata={\n",
|
||||
" \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" block_size: Optional[int] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\n",
|
||||
" \"help\": \"Optional input sequence length after tokenization. \"\n",
|
||||
" \"The training dataset will be truncated in block of this size for training. \"\n",
|
||||
" \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" overwrite_cache: bool = field(\n",
|
||||
" default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n",
|
||||
" )\n",
|
||||
" preprocessing_num_workers: Optional[int] = field(\n",
|
||||
" default=None,\n",
|
||||
" metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n",
|
||||
" )\n",
|
||||
" text_column_name: Optional[str] = field(\n",
|
||||
" default='text',\n",
|
||||
" metadata={\"help\": \"Column containing main text data.\"},\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def __post_init__(self):\n",
|
||||
" if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n",
|
||||
" raise ValueError(\"Need either a dataset name or a training/validation file.\")\n",
|
||||
" else:\n",
|
||||
" if self.train_file is not None:\n",
|
||||
" extension = self.train_file.split(\".\")[-1]\n",
|
||||
" assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n",
|
||||
" if self.validation_file is not None:\n",
|
||||
" extension = self.validation_file.split(\".\")[-1]\n",
|
||||
" assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class TrainState(train_state.TrainState):\n",
|
||||
" dropout_rng: jnp.ndarray\n",
|
||||
"\n",
|
||||
" def replicate(self):\n",
|
||||
" return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
|
||||
" \"\"\"\n",
|
||||
" Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
|
||||
" Shuffle batches if `shuffle` is `True`.\n",
|
||||
" \"\"\"\n",
|
||||
" steps_per_epoch = len(dataset) // batch_size\n",
|
||||
"\n",
|
||||
" if shuffle:\n",
|
||||
" batch_idx = jax.random.permutation(rng, len(dataset))\n",
|
||||
" else:\n",
|
||||
" batch_idx = jnp.arange(len(dataset))\n",
|
||||
"\n",
|
||||
" batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
|
||||
" batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
|
||||
"\n",
|
||||
" for idx in batch_idx:\n",
|
||||
" batch = dataset[idx]\n",
|
||||
" batch = {k: jnp.array(v) for k, v in batch.items()}\n",
|
||||
"\n",
|
||||
" batch = shard(batch)\n",
|
||||
"\n",
|
||||
" yield batch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def write_train_metric(summary_writer, train_metrics, train_time, step):\n",
|
||||
" summary_writer.scalar(\"train_time\", train_time, step)\n",
|
||||
"\n",
|
||||
" train_metrics = get_metrics(train_metrics)\n",
|
||||
" for key, vals in train_metrics.items():\n",
|
||||
" tag = f\"train_{key}\"\n",
|
||||
" for i, val in enumerate(vals):\n",
|
||||
" summary_writer.scalar(tag, val, step - len(vals) + i + 1)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def write_eval_metric(summary_writer, eval_metrics, step):\n",
|
||||
" for metric_name, value in eval_metrics.items():\n",
|
||||
" summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_learning_rate_fn(\n",
|
||||
" train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n",
|
||||
") -> Callable[[int], jnp.array]:\n",
|
||||
" \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n",
|
||||
" steps_per_epoch = train_ds_size // train_batch_size\n",
|
||||
" num_train_steps = steps_per_epoch * num_train_epochs\n",
|
||||
" warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n",
|
||||
" decay_fn = optax.linear_schedule(\n",
|
||||
" init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n",
|
||||
" )\n",
|
||||
" schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n",
|
||||
" return schedule_fn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b87fb52e-7e4b-4a69-8c63-fe739331c1c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_args = ModelArguments(\n",
|
||||
" model_name_or_path=\"EleutherAI/gpt-neo-125M\",\n",
|
||||
" dtype=\"bfloat16\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "75c30ff3-94ec-437f-af79-17ad8429d7eb",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_args = DataTrainingArguments(\n",
|
||||
" dataset_name=\"code_search_net\", \n",
|
||||
" dataset_config_name=\"python\", \n",
|
||||
" block_size=1024,\n",
|
||||
" max_train_samples=10000, \n",
|
||||
" max_eval_samples=1000, \n",
|
||||
" preprocessing_num_workers=8\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e618fc54-07d1-4e6c-bb48-7e050044a0c8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"bs = 16\n",
|
||||
"training_args = TrainingArguments(\n",
|
||||
" num_train_epochs=1,\n",
|
||||
" output_dir=\"./tmp\", \n",
|
||||
" per_device_train_batch_size=bs, \n",
|
||||
" per_device_eval_batch_size=bs, \n",
|
||||
" learning_rate=3e-4,\n",
|
||||
" weight_decay=0.1,\n",
|
||||
" do_train=True,\n",
|
||||
" do_eval=True,\n",
|
||||
" warmup_steps=100,\n",
|
||||
" push_to_hub=False,\n",
|
||||
" overwrite_output_dir=True,\n",
|
||||
" report_to=None\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "20645921-b444-4e8d-b582-08af685b42d0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def save_checkpoint(model, save_dir, state):\n",
|
||||
" state = jax_utils.unreplicate(state)\n",
|
||||
" print(f\"SAVING CHECKPOINT IN {save_dir}\", end=\" ... \")\n",
|
||||
" model.save_pretrained(\n",
|
||||
" training_args.output_dir,\n",
|
||||
" params=state.params,\n",
|
||||
" push_to_hub=training_args.push_to_hub,\n",
|
||||
" commit_message=f\"Saving weights and logs of epoch {epoch+1}\",\n",
|
||||
" )\n",
|
||||
" with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"wb\") as f:\n",
|
||||
" f.write(to_bytes(state.opt_state))\n",
|
||||
" with open(os.path.join(save_dir, \"training_state.json\"), \"w\") as f:\n",
|
||||
" json.dump({\"step\": state.step.item()}, f)\n",
|
||||
" print(\"DONE\")\n",
|
||||
" \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5d11fd82-435d-4afd-b1eb-9914cd465a18",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def restore_checkpoint(save_dir, state):\n",
|
||||
" print(f\"RESTORING CHECKPOINT FROM {save_dir}\", end=\" ... \")\n",
|
||||
" with open(os.path.join(save_dir, \"flax_model.msgpack\"), \"rb\") as f:\n",
|
||||
" params = from_bytes(state.params, f.read())\n",
|
||||
"\n",
|
||||
" with open(os.path.join(save_dir, \"opt_state.msgpack\"), \"rb\") as f:\n",
|
||||
" opt_state = from_bytes(state.opt_state, f.read())\n",
|
||||
"\n",
|
||||
" with open(os.path.join(save_dir, \"training_state.json\"), \"r\") as f:\n",
|
||||
" training_state = json.load(f)\n",
|
||||
" step = training_state[\"step\"]\n",
|
||||
"\n",
|
||||
" print(\"DONE\")\n",
|
||||
" return params, opt_state, step"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"id": "562e0fd9-83b5-4d51-9232-df97fca4f063",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:absl:A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
|
||||
" 10000,\n",
|
||||
" 128,\n",
|
||||
" training_args.num_train_epochs,\n",
|
||||
" training_args.warmup_steps,\n",
|
||||
" training_args.learning_rate,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"def decay_mask_fn(params):\n",
|
||||
" flat_params = traverse_util.flatten_dict(params)\n",
|
||||
" flat_mask = {\n",
|
||||
" path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n",
|
||||
" for path in flat_params\n",
|
||||
" }\n",
|
||||
" return traverse_util.unflatten_dict(flat_mask)\n",
|
||||
"\n",
|
||||
"optimizer = optax.adamw(\n",
|
||||
" learning_rate=linear_decay_lr_schedule_fn,\n",
|
||||
" b1=training_args.adam_beta1,\n",
|
||||
" b2=training_args.adam_beta2,\n",
|
||||
" eps=training_args.adam_epsilon,\n",
|
||||
" weight_decay=training_args.weight_decay,\n",
|
||||
" mask=decay_mask_fn,\n",
|
||||
" )\n",
|
||||
"model = FlaxAutoModelForCausalLM.from_pretrained(training_args.output_dir)\n",
|
||||
"rng = jax.random.PRNGKey(training_args.seed)\n",
|
||||
"rng, dropout_rng = jax.random.split(rng)\n",
|
||||
"state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"id": "56d2d8be-3516-4f86-b0e5-20b3a9079163",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"RESTORING CHECKPOINT FROM ./tmp ... DONE\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"params, opt_state, step = restore_checkpoint(training_args.output_dir, state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0db5eca6-f081-4c3e-a7f1-3e080da8cc89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if (\n",
|
||||
" os.path.exists(training_args.output_dir)\n",
|
||||
" and os.listdir(training_args.output_dir)\n",
|
||||
" and training_args.do_train\n",
|
||||
" and not training_args.overwrite_output_dir\n",
|
||||
"):\n",
|
||||
" raise ValueError(\n",
|
||||
" f\"Output directory ({training_args.output_dir}) already exists and is not empty.\"\n",
|
||||
" \"Use --overwrite_output_dir to overcome.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Make one log on every process with the configuration for debugging.\n",
|
||||
"logging.basicConfig(\n",
|
||||
" format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
|
||||
" datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
|
||||
" level=logging.INFO,\n",
|
||||
")\n",
|
||||
"# Setup logging, we only want one process per machine to log things on the screen.\n",
|
||||
"logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n",
|
||||
"if jax.process_index() == 0:\n",
|
||||
" datasets.utils.logging.set_verbosity_warning()\n",
|
||||
" transformers.utils.logging.set_verbosity_info()\n",
|
||||
"else:\n",
|
||||
" datasets.utils.logging.set_verbosity_error()\n",
|
||||
" transformers.utils.logging.set_verbosity_error()\n",
|
||||
"\n",
|
||||
"# Set the verbosity to info of the Transformers logger (on main process only):\n",
|
||||
"logger.info(f\"Training/evaluation parameters {training_args}\")\n",
|
||||
"\n",
|
||||
"# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n",
|
||||
"# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n",
|
||||
"# (the dataset will be downloaded automatically from the datasets Hub).\n",
|
||||
"#\n",
|
||||
"# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n",
|
||||
"# 'text' is found. You can easily tweak this behavior (see below).\n",
|
||||
"#\n",
|
||||
"# In distributed training, the load_dataset function guarantees that only one local process can concurrently\n",
|
||||
"# download the dataset.\n",
|
||||
"if data_args.dataset_name is not None:\n",
|
||||
" # Downloading and loading a dataset from the hub.\n",
|
||||
" dataset = load_dataset(\n",
|
||||
" data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if \"validation\" not in dataset.keys():\n",
|
||||
" dataset[\"validation\"] = load_dataset(\n",
|
||||
" data_args.dataset_name,\n",
|
||||
" data_args.dataset_config_name,\n",
|
||||
" split=f\"train[:{data_args.validation_split_percentage}%]\",\n",
|
||||
" cache_dir=model_args.cache_dir,\n",
|
||||
" )\n",
|
||||
" dataset[\"train\"] = load_dataset(\n",
|
||||
" data_args.dataset_name,\n",
|
||||
" data_args.dataset_config_name,\n",
|
||||
" split=f\"train[{data_args.validation_split_percentage}%:]\",\n",
|
||||
" cache_dir=model_args.cache_dir,\n",
|
||||
" )\n",
|
||||
"else:\n",
|
||||
" data_files = {}\n",
|
||||
" if data_args.train_file is not None:\n",
|
||||
" data_files[\"train\"] = data_args.train_file\n",
|
||||
" if data_args.validation_file is not None:\n",
|
||||
" data_files[\"validation\"] = data_args.validation_file\n",
|
||||
" extension = data_args.train_file.split(\".\")[-1]\n",
|
||||
" if extension == \"txt\":\n",
|
||||
" extension = \"text\"\n",
|
||||
" dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)\n",
|
||||
"# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n",
|
||||
"# https://huggingface.co/docs/datasets/loading_datasets.html.\n",
|
||||
"\n",
|
||||
"# Load pretrained model and tokenizer\n",
|
||||
"\n",
|
||||
"# Distributed training:\n",
|
||||
"# The .from_pretrained methods guarantee that only one local process can concurrently\n",
|
||||
"# download model & vocab.\n",
|
||||
"if model_args.config_name:\n",
|
||||
" config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)\n",
|
||||
"elif model_args.model_name_or_path:\n",
|
||||
" config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)\n",
|
||||
"else:\n",
|
||||
" config = CONFIG_MAPPING[model_args.model_type]()\n",
|
||||
" logger.warning(\"You are instantiating a new config instance from scratch.\")\n",
|
||||
"\n",
|
||||
"if model_args.tokenizer_name:\n",
|
||||
" tokenizer = AutoTokenizer.from_pretrained(\n",
|
||||
" model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer\n",
|
||||
" )\n",
|
||||
"elif model_args.model_name_or_path:\n",
|
||||
" tokenizer = AutoTokenizer.from_pretrained(\n",
|
||||
" model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer\n",
|
||||
" )\n",
|
||||
"else:\n",
|
||||
" raise ValueError(\n",
|
||||
" \"You are instantiating a new tokenizer from scratch. This is not supported by this script.\"\n",
|
||||
" \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"if model_args.model_name_or_path:\n",
|
||||
" model = FlaxAutoModelForCausalLM.from_pretrained(\n",
|
||||
" model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)\n",
|
||||
" )\n",
|
||||
"else:\n",
|
||||
" model = FlaxAutoModelForCausalLM.from_config(\n",
|
||||
" config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Preprocessing the datasets.\n",
|
||||
"# First we tokenize all the texts.\n",
|
||||
"if training_args.do_train:\n",
|
||||
" column_names = dataset[\"train\"].column_names\n",
|
||||
"else:\n",
|
||||
" column_names = dataset[\"validation\"].column_names\n",
|
||||
"text_column_name = data_args.text_column_name if data_args.text_column_name in column_names else column_names[0]\n",
|
||||
"\n",
|
||||
"# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function\n",
|
||||
"tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n",
|
||||
"\n",
|
||||
"def tokenize_function(examples):\n",
|
||||
" with CaptureLogger(tok_logger) as cl:\n",
|
||||
" output = tokenizer(examples[text_column_name])\n",
|
||||
" # clm input could be much much longer than block_size\n",
|
||||
" if \"Token indices sequence length is longer than the\" in cl.out:\n",
|
||||
" tok_logger.warning(\n",
|
||||
" \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model.\"\n",
|
||||
" )\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
"tokenized_datasets = dataset.map(\n",
|
||||
" tokenize_function,\n",
|
||||
" batched=True,\n",
|
||||
" num_proc=data_args.preprocessing_num_workers,\n",
|
||||
" remove_columns=column_names,\n",
|
||||
" load_from_cache_file=not data_args.overwrite_cache,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"if data_args.block_size is None:\n",
|
||||
" block_size = tokenizer.model_max_length\n",
|
||||
" if block_size > config.max_position_embeddings:\n",
|
||||
" logger.warning(\n",
|
||||
" f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n",
|
||||
" \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n",
|
||||
" )\n",
|
||||
" block_size = 1024\n",
|
||||
"else:\n",
|
||||
" if data_args.block_size > tokenizer.model_max_length:\n",
|
||||
" logger.warning(\n",
|
||||
" f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model\"\n",
|
||||
" f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n",
|
||||
" )\n",
|
||||
" block_size = min(data_args.block_size, tokenizer.model_max_length)\n",
|
||||
"\n",
|
||||
"# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n",
|
||||
"def group_texts(examples):\n",
|
||||
" # Concatenate all texts.\n",
|
||||
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
|
||||
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
|
||||
" # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
|
||||
" # customize this part to your needs.\n",
|
||||
" total_length = (total_length // block_size) * block_size\n",
|
||||
" # Split by chunks of max_len.\n",
|
||||
" result = {\n",
|
||||
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
|
||||
" for k, t in concatenated_examples.items()\n",
|
||||
" }\n",
|
||||
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
"# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n",
|
||||
"# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n",
|
||||
"# to preprocess.\n",
|
||||
"#\n",
|
||||
"# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:\n",
|
||||
"# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n",
|
||||
"\n",
|
||||
"lm_datasets = tokenized_datasets.map(\n",
|
||||
" group_texts,\n",
|
||||
" batched=True,\n",
|
||||
" num_proc=data_args.preprocessing_num_workers,\n",
|
||||
" load_from_cache_file=not data_args.overwrite_cache,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"if training_args.do_train:\n",
|
||||
" if \"train\" not in tokenized_datasets:\n",
|
||||
" raise ValueError(\"--do_train requires a train dataset\")\n",
|
||||
" train_dataset = lm_datasets[\"train\"]\n",
|
||||
" if data_args.max_train_samples is not None:\n",
|
||||
" train_dataset = train_dataset.select(range(data_args.max_train_samples))\n",
|
||||
"\n",
|
||||
"if training_args.do_eval:\n",
|
||||
" if \"validation\" not in tokenized_datasets:\n",
|
||||
" raise ValueError(\"--do_eval requires a validation dataset\")\n",
|
||||
" eval_dataset = lm_datasets[\"validation\"]\n",
|
||||
" if data_args.max_eval_samples is not None:\n",
|
||||
" eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))\n",
|
||||
"\n",
|
||||
"# Enable tensorboard only on the master node\n",
|
||||
"has_tensorboard = is_tensorboard_available()\n",
|
||||
"if has_tensorboard and jax.process_index() == 0:\n",
|
||||
" try:\n",
|
||||
" from flax.metrics.tensorboard import SummaryWriter\n",
|
||||
"\n",
|
||||
" summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n",
|
||||
" except ImportError as ie:\n",
|
||||
" has_tensorboard = False\n",
|
||||
" logger.warning(\n",
|
||||
" f\"Unable to display metrics through TensorBoard because some package are not installed: {ie}\"\n",
|
||||
" )\n",
|
||||
"else:\n",
|
||||
" logger.warning(\n",
|
||||
" \"Unable to display metrics through TensorBoard because the package is not installed: \"\n",
|
||||
" \"Please run pip install tensorboard to enable.\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Initialize our training\n",
|
||||
"rng = jax.random.PRNGKey(training_args.seed)\n",
|
||||
"rng, dropout_rng = jax.random.split(rng)\n",
|
||||
"\n",
|
||||
"# Store some constant\n",
|
||||
"num_epochs = int(training_args.num_train_epochs)\n",
|
||||
"train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()\n",
|
||||
"eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()\n",
|
||||
"steps_per_epoch = len(train_dataset) // train_batch_size\n",
|
||||
"total_train_steps = steps_per_epoch * num_epochs\n",
|
||||
"\n",
|
||||
"# Create learning rate schedule\n",
|
||||
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
|
||||
" len(train_dataset),\n",
|
||||
" train_batch_size,\n",
|
||||
" training_args.num_train_epochs,\n",
|
||||
" training_args.warmup_steps,\n",
|
||||
" training_args.learning_rate,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# We use Optax's \"masking\" functionality to not apply weight decay\n",
|
||||
"# to bias and LayerNorm scale parameters. decay_mask_fn returns a\n",
|
||||
"# mask boolean with the same structure as the parameters.\n",
|
||||
"# The mask is True for parameters that should be decayed.\n",
|
||||
"# Note that this mask is specifically adapted for FlaxGPT2.\n",
|
||||
"# For other models, one should correct the layer norm parameter naming\n",
|
||||
"# accordingly.\n",
|
||||
"def decay_mask_fn(params):\n",
|
||||
" flat_params = traverse_util.flatten_dict(params)\n",
|
||||
" flat_mask = {\n",
|
||||
" path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n",
|
||||
" for path in flat_params\n",
|
||||
" }\n",
|
||||
" return traverse_util.unflatten_dict(flat_mask)\n",
|
||||
"\n",
|
||||
"# create adam optimizer\n",
|
||||
"adamw = optax.adamw(\n",
|
||||
" learning_rate=linear_decay_lr_schedule_fn,\n",
|
||||
" b1=training_args.adam_beta1,\n",
|
||||
" b2=training_args.adam_beta2,\n",
|
||||
" eps=training_args.adam_epsilon,\n",
|
||||
" weight_decay=training_args.weight_decay,\n",
|
||||
" mask=decay_mask_fn,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Setup train state\n",
|
||||
"state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)\n",
|
||||
"\n",
|
||||
"def loss_fn(logits, labels):\n",
|
||||
" shift_logits = logits[..., :-1, :]\n",
|
||||
" shift_labels = labels[..., 1:]\n",
|
||||
" loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))\n",
|
||||
" return loss.mean()\n",
|
||||
"\n",
|
||||
"# Define gradient update step fn\n",
|
||||
"def train_step(state, batch):\n",
|
||||
" dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)\n",
|
||||
"\n",
|
||||
" def compute_loss(params):\n",
|
||||
" labels = batch.pop(\"labels\")\n",
|
||||
" logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n",
|
||||
" loss = loss_fn(logits, labels)\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
" grad_fn = jax.value_and_grad(compute_loss)\n",
|
||||
" loss, grad = grad_fn(state.params)\n",
|
||||
" grad = jax.lax.pmean(grad, \"batch\")\n",
|
||||
"\n",
|
||||
" new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)\n",
|
||||
"\n",
|
||||
" metrics = {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}\n",
|
||||
" metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n",
|
||||
"\n",
|
||||
" return new_state, metrics\n",
|
||||
"\n",
|
||||
"# Define eval fn\n",
|
||||
"def eval_step(params, batch):\n",
|
||||
" labels = batch.pop(\"labels\")\n",
|
||||
" logits = model(**batch, params=params, train=False)[0]\n",
|
||||
" loss = loss_fn(logits, labels)\n",
|
||||
"\n",
|
||||
" # summarize metrics\n",
|
||||
" metrics = {\"loss\": loss}\n",
|
||||
" metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n",
|
||||
" return metrics\n",
|
||||
"\n",
|
||||
"# Create parallel version of the train and eval step\n",
|
||||
"p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n",
|
||||
"p_eval_step = jax.pmap(eval_step, \"batch\")\n",
|
||||
"\n",
|
||||
"# Replicate the train state on each device\n",
|
||||
"state = state.replicate()\n",
|
||||
"\n",
|
||||
"logger.info(\"***** Running training *****\")\n",
|
||||
"logger.info(f\" Num examples = {len(train_dataset)}\")\n",
|
||||
"logger.info(f\" Num Epochs = {num_epochs}\")\n",
|
||||
"logger.info(f\" Instantaneous batch size per device = {training_args.per_device_train_batch_size}\")\n",
|
||||
"logger.info(f\" Total train batch size (w. parallel & distributed) = {train_batch_size}\")\n",
|
||||
"logger.info(f\" Total optimization steps = {total_train_steps}\")\n",
|
||||
"\n",
|
||||
"train_time = 0\n",
|
||||
"train_metrics = []\n",
|
||||
"epochs = tqdm(range(num_epochs), desc=f\"Epoch ... (1/{num_epochs})\", position=0)\n",
|
||||
"for epoch in epochs:\n",
|
||||
" # ======================== Training ================================\n",
|
||||
" train_start = time.time()\n",
|
||||
"\n",
|
||||
" # Create sampling rng\n",
|
||||
" rng, input_rng = jax.random.split(rng)\n",
|
||||
"\n",
|
||||
" # Generate an epoch by shuffling sampling indices from the train dataset\n",
|
||||
" train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)\n",
|
||||
" steps_per_epoch = len(train_dataset) // train_batch_size\n",
|
||||
" # train\n",
|
||||
" for step in tqdm(range(steps_per_epoch), desc=\"Training...\", position=1, leave=False):\n",
|
||||
" batch = next(train_loader)\n",
|
||||
" state, train_metric = p_train_step(state, batch)\n",
|
||||
" train_metrics.append(train_metric)\n",
|
||||
"\n",
|
||||
" cur_step = epoch * (len(train_dataset) // train_batch_size) + step\n",
|
||||
"\n",
|
||||
" if cur_step % training_args.logging_steps == 0 and cur_step > 0:\n",
|
||||
" # Save metrics\n",
|
||||
" train_metric = unreplicate(train_metric)\n",
|
||||
" train_time += time.time() - train_start\n",
|
||||
" if has_tensorboard and jax.process_index() == 0:\n",
|
||||
" write_train_metric(summary_writer, train_metrics, train_time, cur_step)\n",
|
||||
"\n",
|
||||
" epochs.write(\n",
|
||||
" f\"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" train_metrics = []\n",
|
||||
"\n",
|
||||
" # ======================== Evaluating ==============================\n",
|
||||
" eval_metrics = []\n",
|
||||
" eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)\n",
|
||||
" eval_steps = len(eval_dataset) // eval_batch_size\n",
|
||||
" for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n",
|
||||
" # Model forward\n",
|
||||
" batch = next(eval_loader)\n",
|
||||
" metrics = p_eval_step(state.params, batch)\n",
|
||||
" eval_metrics.append(metrics)\n",
|
||||
"\n",
|
||||
" # normalize eval metrics\n",
|
||||
" eval_metrics = get_metrics(eval_metrics)\n",
|
||||
"\n",
|
||||
" eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n",
|
||||
" except OverflowError:\n",
|
||||
" eval_metrics[\"perplexity\"] = float(\"inf\")\n",
|
||||
"\n",
|
||||
" # Print metrics and update progress bar\n",
|
||||
" desc = f\"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})\"\n",
|
||||
" epochs.write(desc)\n",
|
||||
" epochs.desc = desc\n",
|
||||
"\n",
|
||||
" # Save metrics\n",
|
||||
" if has_tensorboard and jax.process_index() == 0:\n",
|
||||
" cur_step = epoch * (len(train_dataset) // train_batch_size)\n",
|
||||
" write_eval_metric(summary_writer, eval_metrics, cur_step)\n",
|
||||
"\n",
|
||||
" # save checkpoint after each epoch and push checkpoint to the hub\n",
|
||||
" if jax.process_index() == 0:\n",
|
||||
" save_checkpoint(model, training_args.output_dir, state)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# if __name__ == \"__main__\":\n",
|
||||
"# main()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f39555f2-ccf7-458c-8756-be1ff330229b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"save_checkpoint(model, training_args.output_dir, state)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,91 +0,0 @@
|
||||
import requests
|
||||
import sys
|
||||
|
||||
# Insert GitHub API token here, in place of *TOKEN*.
|
||||
headers = {"Authorization": "token *TOKEN*"}
|
||||
|
||||
# Constants & language argument.
|
||||
NUM_REPOS = 10_000
|
||||
MIN_STARS = 50
|
||||
LANGUAGE = "java" if len(sys.argv) <= 1 else sys.argv[1]
|
||||
|
||||
|
||||
def main():
|
||||
repositories = set() # Use a set to avoid duplicate entries across pages.
|
||||
max_stars = 1_000_000_000 # Initialize at a very high value.
|
||||
while len(repositories) < NUM_REPOS:
|
||||
new_repositories = run_query(max_stars)
|
||||
max_stars = min([stars for _, stars in new_repositories])
|
||||
# If a query returns no new repositories, drop it.
|
||||
if len(repositories | new_repositories) == len(repositories):
|
||||
break
|
||||
repositories.update(new_repositories)
|
||||
print(f'Collected {len(repositories):,} repositories so far; lowest number of stars: {max_stars:,}')
|
||||
|
||||
with open(f'{LANGUAGE}-top-repos.txt', 'w') as f:
|
||||
for repository, _ in sorted(repositories, key=lambda e: e[1], reverse=True):
|
||||
f.write(f'{repository}\n')
|
||||
|
||||
|
||||
def run_query(max_stars):
|
||||
end_cursor = None # Used to track pagination.
|
||||
repositories = set()
|
||||
|
||||
while end_cursor != "":
|
||||
# Extracts non-fork, recently active repositories in the provided language, in groups of 100.
|
||||
# Leaves placeholders for maximum stars and page cursor. The former allows us to retrieve more than 1,000 repositories
|
||||
# by repeatedly lowering the bar.
|
||||
query = f"""
|
||||
{{
|
||||
search(query: "language:{LANGUAGE} fork:false pushed:>2020-01-01 sort:stars stars:<{max_stars}", type: REPOSITORY, first: 100 {', after: "' + end_cursor + '"' if end_cursor else ''}) {{
|
||||
edges {{
|
||||
node {{
|
||||
... on Repository {{
|
||||
url
|
||||
isPrivate
|
||||
isDisabled
|
||||
isLocked
|
||||
stargazers {{
|
||||
totalCount
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
pageInfo {{
|
||||
hasNextPage
|
||||
endCursor
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
print(f' Retrieving next page; {len(repositories)} repositories in this batch so far.')
|
||||
request = requests.post('https://api.github.com/graphql', json={'query': query}, headers=headers)
|
||||
content = request.json()
|
||||
end_cursor = get_end_cursor(content)
|
||||
repositories.update(get_repositories(content))
|
||||
if len(repositories) > NUM_REPOS:
|
||||
break
|
||||
return repositories
|
||||
|
||||
|
||||
def get_end_cursor(content):
|
||||
page_info = content['data']['search']['pageInfo']
|
||||
has_next_page = page_info['hasNextPage']
|
||||
if has_next_page:
|
||||
return page_info['endCursor']
|
||||
return ""
|
||||
|
||||
|
||||
def get_repositories(content):
|
||||
edges = content['data']['search']['edges']
|
||||
repositories_with_stars = []
|
||||
for edge in edges:
|
||||
if edge['node']['isPrivate'] is False and edge['node']['isDisabled'] is False and edge['node']['isLocked'] is False:
|
||||
repository = edge['node']['url']
|
||||
star_count = edge['node']['stargazers']['totalCount']
|
||||
repositories_with_stars.append((repository, star_count))
|
||||
return repositories_with_stars
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,10 +0,0 @@
|
||||
in_file=$1
|
||||
language=$2
|
||||
cat $in_file | xargs -P16 -n1 -I% bash -c 'echo %; \
|
||||
name=$(echo % | cut -d"/" -f2); \
|
||||
org=$(echo % | cut -d"/" -f1); \
|
||||
echo "Cloning $org/$name"
|
||||
DIR=Repos/'$language'/$org; \
|
||||
mkdir -p $DIR; \
|
||||
echo $DIR; \
|
||||
git clone -q --depth 1 https://github.com/$org/$name $DIR/$name'
|
@ -1,50 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import pygments
|
||||
from pygments.lexers import get_lexer_by_name
|
||||
from pygments.token import Token
|
||||
|
||||
def main():
|
||||
if len(sys.argv) <= 3:
|
||||
raise ValueError('Provide at least an input directory, an output directory, and a language.')
|
||||
input_directory = sys.argv[1]
|
||||
output_directory = sys.argv[2]
|
||||
language = sys.argv[3]
|
||||
if input_directory.endswith('/'):
|
||||
input_directory = input_directory[:-1]
|
||||
if not os.path.exists(output_directory):
|
||||
os.makedirs(output_directory)
|
||||
|
||||
lexer = get_lexer_by_name(language)
|
||||
language_extensions = set(ext.lower()[1:] for ext in lexer.filenames)
|
||||
|
||||
with open(os.path.join(output_directory, os.path.basename(input_directory) + '.txt'), 'w') as f_out:
|
||||
for root, _, files in os.walk(input_directory):
|
||||
for name in files:
|
||||
ext = os.path.splitext(name)[1].lower()
|
||||
if ext in language_extensions:
|
||||
print(f'Lexing: {root}, {name}')
|
||||
lex_file(os.path.join(root, name), f_out, lexer)
|
||||
|
||||
|
||||
def lex_file(file_path, f_out, lexer):
|
||||
with open(file_path, errors='ignore') as f_in:
|
||||
text = f_in.read()
|
||||
|
||||
lexed = []
|
||||
for (ttype, token) in pygments.lex(text, lexer):
|
||||
if ttype in Token.Text:
|
||||
continue
|
||||
elif ttype in Token.Comment:
|
||||
continue
|
||||
else:
|
||||
lexed.append(token.replace('\t', '#TAB#'))
|
||||
|
||||
# Skip empty files.
|
||||
if not lexed:
|
||||
return
|
||||
f_out.write('\t'.join(lexed))
|
||||
f_out.write('\n')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,7 +0,0 @@
|
||||
# Github Scraping Scripts
|
||||
|
||||
This directory contains scripts to scrape github (based on scripts originally by [Vincent Hellendoorn](https://vhellendoorn.github.io/)).
|
||||
|
||||
* `01_metadata_by_github_api.py`: A script to crawl repo metadata using the github API. Currently it crawls N repositories of down to M github stars for language L.
|
||||
* `02_clone_repos.sh`: Actually download the repos.
|
||||
* `03_lexer.py`: Finds and lexes code by file extension. Note that this does some stuff that might be bad for model training such as getting rid of whitespace, so it might need to be fixed.
|
@ -1,487 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "84b1a438-cf1d-402e-a56f-2c4f9dd5ad51",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, FlaxGPTNeoForCausalLM, AutoModelForMaskedLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "7d50cd18-33ed-4b67-82ad-5c48eb9a9b36",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"# model_ckpt = 'EleutherAI/gpt-neo-125M'\n",
|
||||
"model_ckpt = (Path.home()/'gpt-neo-125M-code-clippy').as_posix()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2ec0c4cc-a1bc-4dda-bd0b-72891b519b39",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "065c03c3-2e4a-4f20-a30d-25ada1418b18",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:absl:Starting the local TPU driver.\n",
|
||||
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
||||
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
||||
"model = FlaxGPTNeoForCausalLM.from_pretrained(model_ckpt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "e2f9fb26-2e26-4f57-aa93-e349475203f3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer.pad_token = tokenizer.eos_token"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"id": "75c0c2f6-47ad-41c3-8c66-a1ceeecde061",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = \"\"\"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"x = np.random.randn(10, 10)\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"id": "666977a1-de0d-4900-bf61-ae2b672e51bc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = tokenizer(prompt, return_tensors='jax')\n",
|
||||
"input_ids = inputs.input_ids"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"id": "249e4a8a-7a7e-4e8b-83be-7184a4c0dd0b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(1, 40, 50257)"
|
||||
]
|
||||
},
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"outputs = model(**inputs)\n",
|
||||
"outputs.logits.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "eee873f5-073c-4cbe-8b15-114ea18b2de8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"DeviceArray([[ 198, 11748, 299, 32152, 355, 45941, 198, 11748,\n",
|
||||
" 19798, 292, 355, 279, 67, 198, 11748, 2603,\n",
|
||||
" 29487, 8019, 13, 9078, 29487, 355, 458, 83,\n",
|
||||
" 198, 198, 87, 796, 45941, 13, 25120, 13,\n",
|
||||
" 25192, 77, 7, 940, 11, 838, 8, 198]], dtype=int32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"input_ids"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "82666225-3ab7-405f-9536-4e9e3085be24",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out = model.generate(input_ids,\n",
|
||||
" max_length=200, \n",
|
||||
" num_beams=1,\n",
|
||||
" pad_token_id = tokenizer.pad_token_id\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"id": "c6cc862b-23ef-417d-ae83-1b2eafb0460f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"FlaxGreedySearchOutput(sequences=DeviceArray([[ 198, 11748, 299, 32152, 355, 45941, 198, 11748,\n",
|
||||
" 19798, 292, 355, 279, 67, 198, 11748, 2603,\n",
|
||||
" 29487, 8019, 13, 9078, 29487, 355, 458, 83,\n",
|
||||
" 198, 198, 87, 796, 45941, 13, 25120, 13,\n",
|
||||
" 25192, 77, 7, 940, 11, 838, 8, 198,\n",
|
||||
" 88, 796, 45941, 13, 25120, 13, 25192, 77,\n",
|
||||
" 7, 940, 11, 838, 8, 198, 198, 2,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220,\n",
|
||||
" 220, 220, 220, 220, 220, 220, 220, 220]], dtype=int32))"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "8f6c746a-2d56-4da4-acb5-e066a6a230f2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"x = np.random.randn(10, 10)\n",
|
||||
"y = np.random.randn(10, 10)\n",
|
||||
"\n",
|
||||
"# \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(tokenizer.decode(out[0][0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"id": "b6effeaa-2237-47bc-b0f6-940c4e274c38",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import GPTNeoForCausalLM, AutoModelForCausalLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"id": "3665a3fd-5d92-45e8-8fde-393ec803383a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/arto/transformers/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
|
||||
" pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)\n",
|
||||
"All Flax model weights were used when initializing GPTNeoForCausalLM.\n",
|
||||
"\n",
|
||||
"Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['lm_head.weight', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.4.attn.attention.bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.bias']\n",
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = GPTNeoForCausalLM.from_pretrained(model_ckpt, from_flax=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5fa23301-6f5d-40d5-b614-f14330df894a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transormers import AutoModelForMaskedLM\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(model_ckpt, from_flax=True)\n",
|
||||
"model.save_pretrained(model_ckpt, save_config=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"id": "35114633-bb5f-4c00-ae16-540a7fabb126",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"All Flax model weights were used when initializing GPTNeoForCausalLM.\n",
|
||||
"\n",
|
||||
"Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['lm_head.weight', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.4.attn.attention.bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.bias']\n",
|
||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = AutoModelForCausalLM.from_pretrained(model_ckpt, from_flax=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"id": "15cd3853-1308-46e1-90c1-52b3af0fcac4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = \"\"\"\n",
|
||||
"my_list = ['banana', 'apple', 'orange', 'pineapple']\n",
|
||||
"\n",
|
||||
"#Using brute force method\n",
|
||||
"last_element = my_list[len(my_list) - 1]\n",
|
||||
"\n",
|
||||
"#Using negative indeces\n",
|
||||
"last_element = my_list[-1]\n",
|
||||
"\n",
|
||||
"#Using pop method\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"id": "2f2fc69a-f8f5-4859-bb2e-5c33e63f064a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = \"\"\"\n",
|
||||
"def get_vowels(string):\n",
|
||||
" return [vowel for vowel in string if vowel in 'aeiou'] \n",
|
||||
"\n",
|
||||
"print(\"Vowels are:\",\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"id": "517aa451-3316-45fc-97ab-1a9a52ba55b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = \"\"\"import time\n",
|
||||
"\n",
|
||||
"start_time = time.time()\n",
|
||||
"\n",
|
||||
"total = 0\n",
|
||||
"for i in range(10):\n",
|
||||
" total += i\n",
|
||||
"print(\"Sum:\", total)\n",
|
||||
"\n",
|
||||
"end_time = time.time()\n",
|
||||
"time_taken = \"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 78,
|
||||
"id": "749f6c3d-e1a4-4df7-be81-086024345766",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = tokenizer(prompt, return_tensors='pt')\n",
|
||||
"input_ids = inputs.input_ids"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 81,
|
||||
"id": "6f60e3f0-d051-4df1-8258-7bc479486603",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out = model.generate(input_ids,\n",
|
||||
" max_length=200, \n",
|
||||
" num_beams=1,\n",
|
||||
" pad_token_id = tokenizer.pad_token_id\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 82,
|
||||
"id": "9d17aec3-e42a-43d6-a535-2eeaad2a9c78",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"import time\n",
|
||||
"\n",
|
||||
"start_time = time.time()\n",
|
||||
"\n",
|
||||
"total = 0\n",
|
||||
"for i in range(10):\n",
|
||||
" total += i\n",
|
||||
"print(\"Sum:\", total)\n",
|
||||
"\n",
|
||||
"end_time = time.time()\n",
|
||||
"time_taken = time.time()\n",
|
||||
"\n",
|
||||
"# \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(tokenizer.decode(out[0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"id": "57574549-bd1d-46b0-98ca-352662f735d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save_pretrained(model_ckpt, save_config=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "0578148d-497d-422f-b7fb-b644d2a7c62f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2021-07-06 15:02:08.590730: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
|
||||
"2021-07-06 15:02:08.590769: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import TrainingArguments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "63ffd8ff-8e95-4aad-9068-b27fd8c129bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import fields"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "a9a89020-e7b0-4826-88e6-8ac4f4c6f89e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Field(name='skip_memory_metrics',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object at 0x7f9a12926af0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether or not to skip adding of memory profiler reports to metrics.'}),_field_type=_FIELD)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for f in fields(TrainingArguments):\n",
|
||||
" if f.name == \"skip_memory_metrics\":\n",
|
||||
" print(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "beda87b7-c461-4f92-8988-4255a8e79cf9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
125
metrics/bleu.py
125
metrics/bleu.py
@ -1,125 +0,0 @@
|
||||
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Took the following from CodeXGlue Repository - https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/evaluator/CodeBLEU/bleu.py
|
||||
|
||||
|
||||
"""Python implementation of BLEU and smooth-BLEU.
|
||||
|
||||
This module provides a Python implementation of BLEU and smooth-BLEU.
|
||||
Smooth BLEU is computed following the method outlined in the paper:
|
||||
Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
|
||||
evaluation metrics for machine translation. COLING 2004.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
|
||||
def _get_ngrams(segment, max_order):
|
||||
"""Extracts all n-grams upto a given maximum order from an input segment.
|
||||
|
||||
Args:
|
||||
segment: text segment from which n-grams will be extracted.
|
||||
max_order: maximum length in tokens of the n-grams returned by this
|
||||
methods.
|
||||
|
||||
Returns:
|
||||
The Counter containing all n-grams upto max_order in segment
|
||||
with a count of how many times each n-gram occurred.
|
||||
"""
|
||||
ngram_counts = collections.Counter()
|
||||
for order in range(1, max_order + 1):
|
||||
for i in range(0, len(segment) - order + 1):
|
||||
ngram = tuple(segment[i:i+order])
|
||||
ngram_counts[ngram] += 1
|
||||
return ngram_counts
|
||||
|
||||
|
||||
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
|
||||
smooth=True):
|
||||
"""Computes BLEU score of translated segments against one or more references.
|
||||
|
||||
Args:
|
||||
reference_corpus: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation_corpus: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
|
||||
Returns:
|
||||
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
|
||||
precisions and brevity penalty.
|
||||
"""
|
||||
matches_by_order = [0] * max_order
|
||||
possible_matches_by_order = [0] * max_order
|
||||
reference_length = 0
|
||||
translation_length = 0
|
||||
for (references, translation) in zip(reference_corpus,
|
||||
translation_corpus):
|
||||
reference_length += min(len(r) for r in references)
|
||||
translation_length += len(translation)
|
||||
|
||||
merged_ref_ngram_counts = collections.Counter()
|
||||
for reference in references:
|
||||
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
|
||||
translation_ngram_counts = _get_ngrams(translation, max_order)
|
||||
overlap = translation_ngram_counts & merged_ref_ngram_counts
|
||||
for ngram in overlap:
|
||||
matches_by_order[len(ngram)-1] += overlap[ngram]
|
||||
for order in range(1, max_order+1):
|
||||
possible_matches = len(translation) - order + 1
|
||||
if possible_matches > 0:
|
||||
possible_matches_by_order[order-1] += possible_matches
|
||||
|
||||
precisions = [0] * max_order
|
||||
for i in range(0, max_order):
|
||||
if smooth:
|
||||
precisions[i] = ((matches_by_order[i] + 1.) /
|
||||
(possible_matches_by_order[i] + 1.))
|
||||
else:
|
||||
if possible_matches_by_order[i] > 0:
|
||||
precisions[i] = (float(matches_by_order[i]) /
|
||||
possible_matches_by_order[i])
|
||||
else:
|
||||
precisions[i] = 0.0
|
||||
|
||||
if min(precisions) > 0:
|
||||
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
|
||||
geo_mean = math.exp(p_log_sum)
|
||||
else:
|
||||
geo_mean = 0
|
||||
|
||||
ratio = float(translation_length) / reference_length
|
||||
|
||||
if ratio > 1.0:
|
||||
bp = 1.
|
||||
else:
|
||||
bp = math.exp(1 - 1. / ratio)
|
||||
bleu = geo_mean * bp
|
||||
print(geo_mean)
|
||||
bleu_score_dict = {"bleu":bleu,"precision":precisions,"bp":bp,"ratio":ratio,"trans_len":translation_length,"ref_len":reference_length}
|
||||
return bleu_score_dict#(bleu, precisions, bp, ratio, translation_length, reference_length)
|
||||
|
||||
def bleu_test_case():
|
||||
"""A simple functionality test case to evaluate BLEU"""
|
||||
generated = [[["a","=","b","\n","y","=","a","+","1"]]]
|
||||
reference = [["a","=","b","\n","print","a"]]
|
||||
score_dict = compute_bleu(generated,reference,smooth=False)
|
||||
return score_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
score_dict = bleu_test_case()
|
||||
print(score_dict)
|
@ -1,18 +0,0 @@
|
||||
from metrics.bleu import compute_bleu
|
||||
|
||||
def compute_metrics(references,generated) -> dict:
|
||||
"""
|
||||
Calculates various metrics and returns the calculated dict of these matrics.
|
||||
args:
|
||||
reference: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
returns:
|
||||
A dicitonary with different metrics intact.
|
||||
"""
|
||||
metrics_dict = {} #Update as in new metrics are added over here.
|
||||
metrics_dict["smoothed_bleu_4"] = compute_bleu(references,generated,smooth=True)
|
||||
metrics_dict["bleu_4"] = compute_bleu(references,generated,smooth=False)
|
||||
|
||||
return metrics_dict
|
@ -1,491 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gdown\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from pathlib import Path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"File exists: ../data/repo_infos.csv\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'../data/repo_infos.csv'"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 5
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data_path = Path(\"../data/\")\n",
|
||||
"github_downloader_path = Path(\"../dependency_repos/github-downloader/\")\n",
|
||||
"\n",
|
||||
"gdown.cached_download(\n",
|
||||
" \"https://drive.google.com/uc?id=1T-eBxIZ4S8n6UiI8jdzpsmpk9KgZt793\",\n",
|
||||
" str(data_path/\"repo_infos.csv\"),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"our_repos = pd.read_csv(data_path/\"repo_infos.csv\", parse_dates=True)\n",
|
||||
"eleuther_repos = pd.read_csv(\n",
|
||||
" github_downloader_path/\"github_repositories.csv\", names=[\"name\", \"stargazers\", \"languages\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([ 28043.1 , 70708.05, 338060.77])"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 21
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"np.percentile(our_repos[\"size\"].values, [90, 95, 99])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(544340, 517122)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 22
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(our_repos), len(our_repos[our_repos[\"size\"] < 70708])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"our_filtered_repos = our_repos[our_repos[\"size\"] < 70708]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"648023"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 24
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Combine our repos and EleutherAI's repos\n",
|
||||
"def combine_repos(ours, eleuthers):\n",
|
||||
" # Combine our repos\n",
|
||||
" combined = pd.concat(\n",
|
||||
" [ours, eleuthers],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Remove duplicate repos\n",
|
||||
" dedup_combined = combined[~combined[\"name\"].duplicated(keep=\"last\")]\n",
|
||||
" return dedup_combined\n",
|
||||
"\n",
|
||||
"combined = combine_repos(our_filtered_repos, eleuther_repos)\n",
|
||||
"len(combined)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"combined.to_csv(data_path/\"combined_repos_size_filtered.csv\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"/home/nathan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3441: DtypeWarning: Columns (1,3,4,6,7,12,13,14,15,16,17,18,19,20,21,22,23,24,26) have mixed types.Specify dtype option on import or set low_memory=False.\n exec(code_obj, self.user_global_ns, self.user_ns)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"combined = pd.read_csv(data_path/\"combined_repos_size_filtered.csv\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"Shard the dataset into manageable pieces since EleutherAI's downloader has a memory leak."
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"N_SHARDS = 24\n",
|
||||
"\n",
|
||||
"shards = np.array_split(combined, N_SHARDS)\n",
|
||||
"lens = list(map(len, shards))\n",
|
||||
"assert sum(lens) == len(combined)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for idx, shard in enumerate(shards):\n",
|
||||
" shard.to_csv(data_path/f\"shards/combined_repos_shard_{idx}\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
" name fork project commits branches \\\n",
|
||||
"0 0-1-0/lightblue-0.4 False 8.0 1 \n",
|
||||
"1 0-14n/ndroid False 131.0 1 \n",
|
||||
"3 0-sec/zero-crack False 4.0 1 \n",
|
||||
"4 0-tikaro/minimum-viable-startpage False 15.0 1 \n",
|
||||
"5 0-u-0/dugon-media-server False 52.0 1 \n",
|
||||
"\n",
|
||||
" default branch releases contributors license \\\n",
|
||||
"0 master 0.0 4 GNU General Public License v3.0 \n",
|
||||
"1 master 0.0 2 Other \n",
|
||||
"3 main 1.0 1 GNU General Public License v3.0 \n",
|
||||
"4 master 0.0 ? MIT License \n",
|
||||
"5 master 5.0 ? MIT License \n",
|
||||
"\n",
|
||||
" watchers stargazers ... total issues open issues total pull requests \\\n",
|
||||
"0 14.0 86 ... 9 8 5 \n",
|
||||
"1 5.0 50 ... 1 1 2 \n",
|
||||
"3 0.0 62 ... 0 0 0 \n",
|
||||
"4 4.0 56 ... 0 0 2 \n",
|
||||
"5 2.0 14 ... 5 1 0 \n",
|
||||
"\n",
|
||||
" open pull requests last commit \\\n",
|
||||
"0 0 2020-10-18 21:26:07.0 \n",
|
||||
"1 1 2015-03-17 13:10:07.0 \n",
|
||||
"3 0 2021-05-12 02:03:08.0 \n",
|
||||
"4 1 2019-04-21 09:11:12.0 \n",
|
||||
"5 0 2020-05-16 04:11:45.0 \n",
|
||||
"\n",
|
||||
" last commit SHA has wiki is archived \\\n",
|
||||
"0 9a4f7b37e923b262d2a29894676ff8ed8cde6237 True False \n",
|
||||
"1 4e5dbe69855a7fda8b74e61d9db5aa61e6ba9ee8 True False \n",
|
||||
"3 70ee16550a81b396333565515723d5abab87c719 True False \n",
|
||||
"4 a4fb4aea4474d635c4e4738f7d8c1a485d5d74c8 True False \n",
|
||||
"5 1d6bb1c589e51d2c34b11be20d34dae4bb0c7779 True False \n",
|
||||
"\n",
|
||||
" languages \\\n",
|
||||
"0 NaN \n",
|
||||
"1 C,C++,Objective-C,Shell,Assembly,Haxe,Groff,Py... \n",
|
||||
"3 Python \n",
|
||||
"4 JavaScript,CSS,HTML \n",
|
||||
"5 JavaScript,Dockerfile \n",
|
||||
"\n",
|
||||
" labels \n",
|
||||
"0 NaN \n",
|
||||
"1 bug,duplicate,enhancement,help wanted,invalid,... \n",
|
||||
"3 bug,documentation,duplicate,enhancement,good f... \n",
|
||||
"4 bug,duplicate,enhancement,good first issue,hel... \n",
|
||||
"5 bug,documentation,duplicate,enhancement,featur... \n",
|
||||
"\n",
|
||||
"[5 rows x 27 columns]"
|
||||
],
|
||||
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>name</th>\n <th>fork project</th>\n <th>commits</th>\n <th>branches</th>\n <th>default branch</th>\n <th>releases</th>\n <th>contributors</th>\n <th>license</th>\n <th>watchers</th>\n <th>stargazers</th>\n <th>...</th>\n <th>total issues</th>\n <th>open issues</th>\n <th>total pull requests</th>\n <th>open pull requests</th>\n <th>last commit</th>\n <th>last commit SHA</th>\n <th>has wiki</th>\n <th>is archived</th>\n <th>languages</th>\n <th>labels</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>0-1-0/lightblue-0.4</td>\n <td>False</td>\n <td>8.0</td>\n <td>1</td>\n <td>master</td>\n <td>0.0</td>\n <td>4</td>\n <td>GNU General Public License v3.0</td>\n <td>14.0</td>\n <td>86</td>\n <td>...</td>\n <td>9</td>\n <td>8</td>\n <td>5</td>\n <td>0</td>\n <td>2020-10-18 21:26:07.0</td>\n <td>9a4f7b37e923b262d2a29894676ff8ed8cde6237</td>\n <td>True</td>\n <td>False</td>\n <td>NaN</td>\n <td>NaN</td>\n </tr>\n <tr>\n <th>1</th>\n <td>0-14n/ndroid</td>\n <td>False</td>\n <td>131.0</td>\n <td>1</td>\n <td>master</td>\n <td>0.0</td>\n <td>2</td>\n <td>Other</td>\n <td>5.0</td>\n <td>50</td>\n <td>...</td>\n <td>1</td>\n <td>1</td>\n <td>2</td>\n <td>1</td>\n <td>2015-03-17 13:10:07.0</td>\n <td>4e5dbe69855a7fda8b74e61d9db5aa61e6ba9ee8</td>\n <td>True</td>\n <td>False</td>\n <td>C,C++,Objective-C,Shell,Assembly,Haxe,Groff,Py...</td>\n <td>bug,duplicate,enhancement,help wanted,invalid,...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>0-sec/zero-crack</td>\n <td>False</td>\n <td>4.0</td>\n <td>1</td>\n <td>main</td>\n <td>1.0</td>\n <td>1</td>\n <td>GNU General Public License v3.0</td>\n <td>0.0</td>\n <td>62</td>\n <td>...</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>0</td>\n <td>2021-05-12 02:03:08.0</td>\n <td>70ee16550a81b396333565515723d5abab87c719</td>\n <td>True</td>\n <td>False</td>\n <td>Python</td>\n <td>bug,documentation,duplicate,enhancement,good f...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>0-tikaro/minimum-viable-startpage</td>\n <td>False</td>\n <td>15.0</td>\n <td>1</td>\n <td>master</td>\n <td>0.0</td>\n <td>?</td>\n <td>MIT License</td>\n <td>4.0</td>\n <td>56</td>\n <td>...</td>\n <td>0</td>\n <td>0</td>\n <td>2</td>\n <td>1</td>\n <td>2019-04-21 09:11:12.0</td>\n <td>a4fb4aea4474d635c4e4738f7d8c1a485d5d74c8</td>\n <td>True</td>\n <td>False</td>\n <td>JavaScript,CSS,HTML</td>\n <td>bug,duplicate,enhancement,good first issue,hel...</td>\n </tr>\n <tr>\n <th>5</th>\n <td>0-u-0/dugon-media-server</td>\n <td>False</td>\n <td>52.0</td>\n <td>1</td>\n <td>master</td>\n <td>5.0</td>\n <td>?</td>\n <td>MIT License</td>\n <td>2.0</td>\n <td>14</td>\n <td>...</td>\n <td>5</td>\n <td>1</td>\n <td>0</td>\n <td>0</td>\n <td>2020-05-16 04:11:45.0</td>\n <td>1d6bb1c589e51d2c34b11be20d34dae4bb0c7779</td>\n <td>True</td>\n <td>False</td>\n <td>JavaScript,Dockerfile</td>\n <td>bug,documentation,duplicate,enhancement,featur...</td>\n </tr>\n </tbody>\n</table>\n<p>5 rows × 27 columns</p>\n</div>"
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 45
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"combined.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def convert_to_gh_downloader(df, path):\n",
|
||||
" # select only name, stargazers, and languages columns and save to path\n",
|
||||
" df = df[['name', 'stargazers', 'languages']]\n",
|
||||
" df.to_csv(path/\"combined_github_repositories.csv\", index=False, header=False)\n",
|
||||
"\n",
|
||||
"convert_to_gh_downloader(combined.head(200_000), github_downloader_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once you've converted the combined data to the proper format, change directory to the `github-downloader` repository, install the requirements `pip install python-magic fire tqdm requests lm-dataformat` (the `requirements.txt` doesn't work) and decompression tool `sudo apt install zstd`, edit the file `download_repo_text.py` so that it reads the file `combined_github_repositories.csv` rather than `github_repositories.csv`, and run that file. This will generate a bunch of `json.zst` files containing the text of all the files in the repositories and their metadata (the repo and file it came from). Next decompress the files `zstd -d <file_path>` so that you can read them using pandas."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text_df = pd.read_json(github_downloader_path/\"github_data/data_0_time1625430673_default.jsonl\", orient=\"records\",lines=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>text</th>\n",
|
||||
" <th>meta</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>{\\n \"name\": \"dugon-media-server\",\\n \"version...</td>\n",
|
||||
" <td>{'repo_name': '0-u-0/dugon-media-server', 'sta...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>\\nconst DEFAULT_CONFIG_PATH = './config';\\ncon...</td>\n",
|
||||
" <td>{'repo_name': '0-u-0/dugon-media-server', 'sta...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>module.exports = {\\n debug: true,\\n id: \"foo...</td>\n",
|
||||
" <td>{'repo_name': '0-u-0/dugon-media-server', 'sta...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td># Dugon Signal Server (Javascript)\\n\\n[![GitHu...</td>\n",
|
||||
" <td>{'repo_name': '0-u-0/dugon-media-server', 'sta...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>FROM node:carbon\\n\\nRUN \\\\n\\tset -x \\\\n\\t&& ap...</td>\n",
|
||||
" <td>{'repo_name': '0-u-0/dugon-media-server', 'sta...</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" text \\\n",
|
||||
"0 {\\n \"name\": \"dugon-media-server\",\\n \"version... \n",
|
||||
"1 \\nconst DEFAULT_CONFIG_PATH = './config';\\ncon... \n",
|
||||
"2 module.exports = {\\n debug: true,\\n id: \"foo... \n",
|
||||
"3 # Dugon Signal Server (Javascript)\\n\\n[![GitHu... \n",
|
||||
"4 FROM node:carbon\\n\\nRUN \\\\n\\tset -x \\\\n\\t&& ap... \n",
|
||||
"\n",
|
||||
" meta \n",
|
||||
"0 {'repo_name': '0-u-0/dugon-media-server', 'sta... \n",
|
||||
"1 {'repo_name': '0-u-0/dugon-media-server', 'sta... \n",
|
||||
"2 {'repo_name': '0-u-0/dugon-media-server', 'sta... \n",
|
||||
"3 {'repo_name': '0-u-0/dugon-media-server', 'sta... \n",
|
||||
"4 {'repo_name': '0-u-0/dugon-media-server', 'sta... "
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"text_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(text_df.text.values[100])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"2045"
|
||||
]
|
||||
},
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(text_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\"level\":\"warn\",\"ts\":1625538732.958198,\"caller\":\"roundtripper/rate_limit.go:58\",\"msg\":\"Rate limit exceeded. Waiting 49m9.041814722s to retry...\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_391514/951878845.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m results = check_output(\n\u001b[0m\u001b[1;32m 7\u001b[0m [\n\u001b[1;32m 8\u001b[0m \u001b[0mscorecard_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/usr/lib/python3.8/subprocess.py\u001b[0m in \u001b[0;36mcheck_output\u001b[0;34m(timeout, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 413\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mempty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 415\u001b[0;31m return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,\n\u001b[0m\u001b[1;32m 416\u001b[0m **kwargs).stdout\n\u001b[1;32m 417\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/usr/lib/python3.8/subprocess.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(input, capture_output, timeout, check, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mPopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mpopenargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 494\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 495\u001b[0;31m \u001b[0mstdout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstderr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommunicate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 496\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTimeoutExpired\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkill\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/usr/lib/python3.8/subprocess.py\u001b[0m in \u001b[0;36mcommunicate\u001b[0;34m(self, input, timeout)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stdin_write\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstdout\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1015\u001b[0;31m \u001b[0mstdout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstdout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstdout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstderr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from subprocess import CalledProcessError, check_output\n",
|
||||
"\n",
|
||||
"scorecard_path = Path(\"../scripts/scorecard\")\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" results = check_output(\n",
|
||||
" [\n",
|
||||
" scorecard_path,\n",
|
||||
" f\"--repo=github.com/kubernetes/kubernetes\",\n",
|
||||
" \"--checks=Vulnerabilities\",\n",
|
||||
" \"--format=csv\",\n",
|
||||
" ]\n",
|
||||
" ).decode(\"utf-8\")\n",
|
||||
"except CalledProcessError as e:\n",
|
||||
" # Exception thrown when the method is malformed, i.e, it is missing a curly brace\n",
|
||||
" error = e.output.decode(\"utf-8\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3.8.10 64-bit"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
#! /bin/bash
|
||||
./run_clm_flax.py \
|
||||
--output_dir $HOME/tmp/gpt-neo-125M-code-clippy \
|
||||
--model_name_or_path="EleutherAI/gpt-neo-125M" \
|
||||
--dataset_name="code_search_net" \
|
||||
--dataset_config_name="python" \
|
||||
--text_column_name="func_code_string" \
|
||||
--do_train --do_eval \
|
||||
--block_size="2048" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--per_device_eval_batch_size="16" \
|
||||
--preprocessing_num_workers="8" \
|
||||
--learning_rate="6e-4" \
|
||||
--adafactor \
|
||||
--warmup_steps="100" \
|
||||
--adam_beta1="0.9" \
|
||||
--adam_beta2="0.98" \
|
||||
--weight_decay="0.01" \
|
||||
--overwrite_output_dir \
|
||||
--num_train_epochs="10" \
|
||||
--logging_steps="100" \
|
||||
--eval_steps="200" \
|
||||
--push_to_hub="False" \
|
||||
--report_to="none" \
|
||||
--dtype="bfloat16" \
|
||||
--skip_memory_metrics="False" \
|
||||
--save_steps="200" \
|
||||
--save_total_limit 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
# --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
|
||||
# --max_train_samples="10000" \
|
||||
# --max_eval_samples="1000"
|
@ -1,37 +0,0 @@
|
||||
#! /bin/bash
|
||||
./run_clm_streaming_flax_v2.py \
|
||||
--output_dir $HOME/gpt-neo-125M-test \
|
||||
--model_name_or_path="EleutherAI/gpt-neo-125M" \
|
||||
--dataset_name $HOME/gpt-code-clippy/code_clippy.py \
|
||||
--data_dir /home/shared/code-clippy-dataset/merged-data \
|
||||
--text_column_name="text" \
|
||||
--do_train --do_eval \
|
||||
--block_size="2048" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--per_device_eval_batch_size="16" \
|
||||
--preprocessing_num_workers="8" \
|
||||
--learning_rate="6e-4" \
|
||||
--max_steps 500 \
|
||||
--warmup_steps 150 \
|
||||
--decay_steps 250 \
|
||||
--adam_beta1="0.9" \
|
||||
--adam_beta2="0.95" \
|
||||
--weight_decay="0.01" \
|
||||
--overwrite_output_dir \
|
||||
--logging_steps="10" \
|
||||
--eval_steps="50" \
|
||||
--push_to_hub="True" \
|
||||
--report_to="all" \
|
||||
--dtype="bfloat16" \
|
||||
--skip_memory_metrics="False" \
|
||||
--save_steps="50" \
|
||||
--save_total_limit 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--report_to="wandb" \
|
||||
--run_name="testing-mini" \
|
||||
--max_eval_samples 100 \
|
||||
--save_optimizer true \
|
||||
# --adafactor \
|
||||
# --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ \
|
||||
# --max_train_samples="10000" \
|
||||
|
@ -1,37 +0,0 @@
|
||||
#! /bin/bash
|
||||
./run_clm_streaming_flax_clean.py \
|
||||
--output_dir $HOME/tmp/gpt-neo-125M-test \
|
||||
--model_name_or_path="EleutherAI/gpt-neo-125M" \
|
||||
--dataset_name="wikitext" \
|
||||
--dataset_config_name="wikitext-103-raw-v1" \
|
||||
--text_column_name="text" \
|
||||
--do_train --do_eval \
|
||||
--block_size="128" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--per_device_eval_batch_size="8" \
|
||||
--preprocessing_num_workers="8" \
|
||||
--learning_rate="6e-4" \
|
||||
--max_steps 500 \
|
||||
--warmup_steps 150 \
|
||||
--decay_steps 250 \
|
||||
--adam_beta1="0.9" \
|
||||
--adam_beta2="0.95" \
|
||||
--weight_decay="0.1" \
|
||||
--overwrite_output_dir \
|
||||
--logging_steps="10" \
|
||||
--eval_steps="50" \
|
||||
--push_to_hub="False" \
|
||||
--report_to="all" \
|
||||
--dtype="bfloat16" \
|
||||
--skip_memory_metrics="False" \
|
||||
--save_steps="50" \
|
||||
--save_total_limit 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--report_to="none" \
|
||||
--run_name="testing-mini" \
|
||||
--max_eval_samples 100 \
|
||||
--save_optimizer true \
|
||||
# --resume_from_checkpoint $HOME/gpt-neo-125M-test/ckpt-800 \
|
||||
# --adafactor \
|
||||
# --max_train_samples="10000" \
|
||||
|
@ -1,35 +0,0 @@
|
||||
#! /bin/bash
|
||||
./run_clm_flax.py \
|
||||
--output_dir $HOME/tmp/gpt-neo-125M-test-3 \
|
||||
--model_name_or_path="EleutherAI/gpt-neo-125M" \
|
||||
--dataset_name="wikitext" \
|
||||
--dataset_config_name="wikitext-2-raw-v1" \
|
||||
--text_column_name="text" \
|
||||
--do_train --do_eval \
|
||||
--block_size="128" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--per_device_eval_batch_size="16" \
|
||||
--preprocessing_num_workers="8" \
|
||||
--learning_rate="2e-5" \
|
||||
--warmup_steps="100" \
|
||||
--adam_beta1="0.9" \
|
||||
--adam_beta2="0.98" \
|
||||
--weight_decay="0.01" \
|
||||
--overwrite_output_dir \
|
||||
--num_train_epochs="10" \
|
||||
--logging_steps="10" \
|
||||
--eval_steps="10" \
|
||||
--push_to_hub="False" \
|
||||
--report_to="none" \
|
||||
--run_name="test-non-streaming" \
|
||||
--dtype="bfloat16" \
|
||||
--skip_memory_metrics="False" \
|
||||
--save_steps="20" \
|
||||
--save_strategy steps \
|
||||
--save_total_limit 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--save_optimizer true \
|
||||
--resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-3/ckpt-640 \
|
||||
# --adafactor \
|
||||
# --max_train_samples="10000" \
|
||||
# --max_eval_samples="1000"
|
@ -1,45 +0,0 @@
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from fastcore.script import *
|
||||
from ghapi.all import GhApi
|
||||
|
||||
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
|
||||
ISSUE_TITLE = "Participation in an Open Source Language Modeling Dataset"
|
||||
ISSUE_BODY = """Hi there, your repository has been selected to be included in an effort
|
||||
to train an open source version of GitHub and OpenAI's [Copilot tool](https://copilot.github.com/).
|
||||
You can find more information on our project [here](https://github.com/ncoop57/gpt-code-clippy).
|
||||
|
||||
If you are the owner/admin of this repository and would like to opt-out of this,
|
||||
please downvote this issue before July 9th and we will remove your repository
|
||||
from our list. We will comment an acknowledgement in this issue if you choose
|
||||
to opt-out on July 9th. If you have any questions, please open an issue on our
|
||||
[repository](https://github.com/ncoop57/gpt-code-clippy/issues/new).
|
||||
"""
|
||||
|
||||
REPOS = [
|
||||
"ProgrammingSpace/test-repo-3",
|
||||
"ncoop57/deep_parking",
|
||||
"ncoop57/test-repo",
|
||||
"ncoop57/recipe_name_suggester",
|
||||
"ncoop57/test-repo-2",
|
||||
]
|
||||
|
||||
# Open issue on repo using custom title and body
|
||||
def open_issue(owner, repo):
|
||||
api = GhApi(owner=owner, repo=repo, token=GITHUB_TOKEN)
|
||||
api.issues.create(title=ISSUE_TITLE, body=ISSUE_BODY)
|
||||
|
||||
|
||||
@call_parse
|
||||
def main(repos_path: Param("Path to the csv containing all of the repos", str)):
|
||||
"""
|
||||
Use pandas dataframe from the repos path to open issues in each of them.
|
||||
"""
|
||||
df = pd.read_csv(repos_path)
|
||||
|
||||
# Loop through repos and open issue for each repo
|
||||
for _, row in df.iterrows():
|
||||
owner, repo = row["name"].split("/")
|
||||
open_issue(owner=owner, repo=repo)
|
@ -1,27 +0,0 @@
|
||||
import os
|
||||
|
||||
from ghapi.all import GhApi
|
||||
|
||||
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
|
||||
ISSUE_TITLE = "Participation in an Open Source Language Modeling Dataset"
|
||||
REPOS = [
|
||||
"ProgrammingSpace/test-repo-3",
|
||||
"ncoop57/deep_parking",
|
||||
"ncoop57/test-repo",
|
||||
"ncoop57/recipe_name_suggester",
|
||||
"ncoop57/test-repo-2",
|
||||
]
|
||||
|
||||
for r in REPOS:
|
||||
print(r)
|
||||
owner, repo = r.split("/")
|
||||
api = GhApi(owner=owner, repo=repo, token=GITHUB_TOKEN)
|
||||
issues = api.issues.list_for_repo(owner=owner, repo=repo, state="all")
|
||||
for i in issues:
|
||||
if i.title == ISSUE_TITLE:
|
||||
if i.reactions["-1"] > 0:
|
||||
print(f"{r} is opting out")
|
||||
api.issues.create_comment(
|
||||
i.number, body="Thank you, your repository has been removed."
|
||||
)
|
||||
api.issues.update(i.number, state="closed")
|
@ -290,9 +290,6 @@ def fake_update(state):
|
||||
def reinstantiate_states(opt_state):
|
||||
new_state = []
|
||||
for state in opt_state:
|
||||
if isinstance(state, list):
|
||||
new_state.append(reinstantiate_states(state))
|
||||
else:
|
||||
cls = getattr(optax, type(state).__name__)
|
||||
new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
|
||||
return new_state
|
Loading…
Reference in New Issue
Block a user