Merge pull request #62 from HiroakiMikami/fix-eval

Fix HumanEval evaluation script so that generated samples contain only the completion, not the prompt.
commit 1711eb3123
Author: Nathan Cooper (committed via GitHub)
Date:   2021-10-09 08:20:54 -04:00


@@ -39,7 +39,7 @@ def clean_text(generation):
     return generation
-def generate_text(prompt, n, tokenizer, model):
+def generate_text(prompt, n, tokenizer, model, include_prompt=True):
     inputs = tokenizer(prompt, truncation=True, max_length=MAX_TOKS, return_tensors="pt").to("cuda")
     output_seq = model.generate(
         input_ids=inputs.input_ids, max_length=MAX_TOKS,
@@ -52,7 +52,7 @@ def generate_text(prompt, n, tokenizer, model):
     generated_text = []
     for o in outputs:
         cleaned = clean_text(o.replace(prompt, ""))
-        generated_text.append(prompt + cleaned)
+        generated_text.append(prompt + cleaned if include_prompt else cleaned)
     return generated_text
@@ -92,7 +92,8 @@ def _eval_human_eval(path, out_path, tokenizer, model):
             problems[task_id]["prompt"],
             num_samples_per_task,
             tokenizer,
-            model
+            model,
+            include_prompt=False,
         ):
            samples.append(dict(task_id=task_id, completion=text))