gpt-code-clippy/code-clippy-app.ipynb
2021-07-19 00:49:15 +00:00

257 lines
9.0 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "333a4f54-e120-4969-8adf-32b98655ff41",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-18 23:45:31.083087: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2021-07-18 23:45:31.083131: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
}
],
"source": [
"import gradio as gr\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d875c6bc-8d97-4e03-9ee4-a4bd47f191bf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/arto/transformers/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)\n",
"All Flax model weights were used when initializing GPTNeoForCausalLM.\n",
"\n",
"Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['transformer.h.4.attn.attention.bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.10.attn.attention.bias', 'lm_head.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model_name = \"/home/shared/models/gpt-code-clippy-125M-apps-lr-adam1e-4-bs128/ckpt-1633/\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, from_flax=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125M\")\n",
"tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e6ebb017-e784-4311-a795-4bacd2263d19",
"metadata": {},
"outputs": [],
"source": [
"def format_input(question, starter_code=\"\"):\n",
" answer_type = \"\\nUse Call-Based format\\n\" if starter_code else \"\\nUse Standard Input format\\n\"\n",
" return f\"\\nQUESTION:\\n{question}\\n{starter_code}\\n{answer_type}\\nANSWER:\\n\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b00ad3ab-3086-48f0-940a-3129a7dff30a",
"metadata": {},
"outputs": [],
"source": [
"def format_outputs(text):\n",
" formatted_text =f'''\n",
" <head>\n",
" <link rel=\"stylesheet\"\n",
" href=\"https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.0.3/styles/default.min.css\">\n",
" <script src=\"https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.0.3/highlight.min.js\"></script>\n",
" <script>hljs.initHighlightingOnLoad();</script>\n",
" </head>\n",
" <body>\n",
" <pre><code class=\"python\">{text}</code></pre>\n",
" </body>\n",
" '''\n",
" return formatted_text"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "28abdab5-962e-47af-83b2-4f9c491ba705",
"metadata": {},
"outputs": [],
"source": [
"def generate_solution(question, starter_code=\"\", temperature=1., num_beams=1):\n",
" prompt = format_input(question, starter_code)\n",
" input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
" start = len(input_ids[0])\n",
" output = model.generate(\n",
" input_ids,\n",
" max_length=start+200,\n",
" do_sample=True,\n",
" top_p=0.95,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
" early_stopping=True,\n",
" temperature=1.,\n",
" num_beams=int(num_beams),\n",
" no_repeat_ngram_size=None,\n",
" repetition_penalty=None,\n",
" num_return_sequences=None,\n",
" )\n",
" \n",
" return format_outputs(tokenizer.decode(output[0][start:]).strip())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a156ff76-7e50-40c9-8b3e-e9ddeb450501",
"metadata": {},
"outputs": [],
"source": [
"_EXAMPLES = [\n",
" [\n",
" \"\"\"\n",
"Given a 2D list of size `m * n`. Your task is to find the sum of minimum value in each row.\n",
"For Example:\n",
"```python\n",
"[\n",
" [1, 2, 3, 4, 5], # minimum value of row is 1\n",
" [5, 6, 7, 8, 9], # minimum value of row is 5\n",
" [20, 21, 34, 56, 100] # minimum value of row is 20\n",
"]\n",
"```\n",
"So, the function should return `26` because sum of minimums is as `1 + 5 + 20 = 26`\n",
" \"\"\",\n",
" \"\",\n",
" 0.8,\n",
" ],\n",
" [\n",
" \"\"\"\n",
"# Personalized greeting\n",
"\n",
"Create a function that gives a personalized greeting. This function takes two parameters: `name` and `owner`.\n",
" \"\"\",\n",
" \"\"\"\n",
"Use conditionals to return the proper message:\n",
"\n",
"case| return\n",
"--- | ---\n",
"name equals owner | 'Hello boss'\n",
"otherwise | 'Hello guest'\n",
"def greet(name, owner):\n",
" \"\"\",\n",
" 0.8,\n",
" ]\n",
"] "
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fb3d90fc-6932-4343-9c86-70ae94ca95aa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running locally at: http://127.0.0.1:7861/\n",
"This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)\n",
"Running on External URL: https://34711.gradio.app\n",
"Interface loading below...\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"900\"\n",
" height=\"500\"\n",
" src=\"https://34711.gradio.app\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7f0bf02db670>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<Flask 'gradio.networking'>,\n",
" 'http://127.0.0.1:7861/',\n",
" 'https://34711.gradio.app')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = [\n",
" gr.inputs.Textbox(placeholder=\"Define a problem here...\", lines=7),\n",
" gr.inputs.Textbox(placeholder=\"Provide optional starter code...\", lines=3),\n",
" gr.inputs.Slider(0.5, 1.5, 0.1, default=0.8, label=\"Temperature\"),\n",
" gr.inputs.Slider(1,4,1,default=1, label=\"Beam size\")\n",
"]\n",
"\n",
"outputs = [\n",
" gr.outputs.HTML(label=\"Solution\")\n",
"]\n",
"\n",
"gr.Interface(\n",
" generate_solution, \n",
" inputs=inputs, \n",
" outputs=outputs,\n",
" title=\"Code Clippy: Problem Solver\",\n",
" examples=_EXAMPLES,\n",
").launch(share=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "737cd94c-5286-4832-9611-06e6f2a89357",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}