mirror of
https://github.com/CodedotAl/gpt-code-clippy.git
synced 2024-09-19 03:19:21 +03:00
257 lines
9.0 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "333a4f54-e120-4969-8adf-32b98655ff41",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2021-07-18 23:45:31.083087: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
|
|
"2021-07-18 23:45:31.083131: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import gradio as gr\n",
|
|
"\n",
|
|
"from transformers import AutoModelForCausalLM, AutoTokenizer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "d875c6bc-8d97-4e03-9ee4-a4bd47f191bf",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/arto/transformers/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
|
|
" pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)\n",
|
|
"All Flax model weights were used when initializing GPTNeoForCausalLM.\n",
|
|
"\n",
|
|
"Some weights of GPTNeoForCausalLM were not initialized from the Flax model and are newly initialized: ['transformer.h.4.attn.attention.bias', 'transformer.h.8.attn.attention.masked_bias', 'transformer.h.0.attn.attention.bias', 'transformer.h.2.attn.attention.masked_bias', 'transformer.h.2.attn.attention.bias', 'transformer.h.8.attn.attention.bias', 'transformer.h.10.attn.attention.masked_bias', 'transformer.h.11.attn.attention.masked_bias', 'transformer.h.3.attn.attention.masked_bias', 'transformer.h.9.attn.attention.masked_bias', 'transformer.h.6.attn.attention.masked_bias', 'transformer.h.7.attn.attention.masked_bias', 'transformer.h.6.attn.attention.bias', 'transformer.h.0.attn.attention.masked_bias', 'transformer.h.1.attn.attention.masked_bias', 'transformer.h.5.attn.attention.masked_bias', 'transformer.h.4.attn.attention.masked_bias', 'transformer.h.10.attn.attention.bias', 'lm_head.weight']\n",
|
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model_name = \"/home/shared/models/gpt-code-clippy-125M-apps-lr-adam1e-4-bs128/ckpt-1633/\"\n",
|
|
"model = AutoModelForCausalLM.from_pretrained(model_name, from_flax=True)\n",
|
|
"tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-125M\")\n",
|
|
"tokenizer.pad_token = tokenizer.eos_token"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "e6ebb017-e784-4311-a795-4bacd2263d19",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def format_input(question, starter_code=\"\"):\n",
|
|
" answer_type = \"\\nUse Call-Based format\\n\" if starter_code else \"\\nUse Standard Input format\\n\"\n",
|
|
" return f\"\\nQUESTION:\\n{question}\\n{starter_code}\\n{answer_type}\\nANSWER:\\n\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "b00ad3ab-3086-48f0-940a-3129a7dff30a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def format_outputs(text):\n",
|
|
" formatted_text =f'''\n",
|
|
" <head>\n",
|
|
" <link rel=\"stylesheet\"\n",
|
|
" href=\"https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.0.3/styles/default.min.css\">\n",
|
|
" <script src=\"https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.0.3/highlight.min.js\"></script>\n",
|
|
" <script>hljs.initHighlightingOnLoad();</script>\n",
|
|
" </head>\n",
|
|
" <body>\n",
|
|
" <pre><code class=\"python\">{text}</code></pre>\n",
|
|
" </body>\n",
|
|
" '''\n",
|
|
" return formatted_text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "28abdab5-962e-47af-83b2-4f9c491ba705",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def generate_solution(question, starter_code=\"\", temperature=1., num_beams=1):\n",
|
|
" prompt = format_input(question, starter_code)\n",
|
|
" input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
|
|
" start = len(input_ids[0])\n",
|
|
" output = model.generate(\n",
|
|
" input_ids,\n",
|
|
" max_length=start+200,\n",
|
|
" do_sample=True,\n",
|
|
" top_p=0.95,\n",
|
|
" pad_token_id=tokenizer.pad_token_id,\n",
|
|
" early_stopping=True,\n",
|
|
" temperature=1.,\n",
|
|
" num_beams=int(num_beams),\n",
|
|
" no_repeat_ngram_size=None,\n",
|
|
" repetition_penalty=None,\n",
|
|
" num_return_sequences=None,\n",
|
|
" )\n",
|
|
" \n",
|
|
" return format_outputs(tokenizer.decode(output[0][start:]).strip())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "a156ff76-7e50-40c9-8b3e-e9ddeb450501",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"_EXAMPLES = [\n",
|
|
" [\n",
|
|
" \"\"\"\n",
|
|
"Given a 2D list of size `m * n`. Your task is to find the sum of minimum value in each row.\n",
|
|
"For Example:\n",
|
|
"```python\n",
|
|
"[\n",
|
|
" [1, 2, 3, 4, 5], # minimum value of row is 1\n",
|
|
" [5, 6, 7, 8, 9], # minimum value of row is 5\n",
|
|
" [20, 21, 34, 56, 100] # minimum value of row is 20\n",
|
|
"]\n",
|
|
"```\n",
|
|
"So, the function should return `26` because sum of minimums is as `1 + 5 + 20 = 26`\n",
|
|
" \"\"\",\n",
|
|
" \"\",\n",
|
|
" 0.8,\n",
|
|
" ],\n",
|
|
" [\n",
|
|
" \"\"\"\n",
|
|
"# Personalized greeting\n",
|
|
"\n",
|
|
"Create a function that gives a personalized greeting. This function takes two parameters: `name` and `owner`.\n",
|
|
" \"\"\",\n",
|
|
" \"\"\"\n",
|
|
"Use conditionals to return the proper message:\n",
|
|
"\n",
|
|
"case| return\n",
|
|
"--- | ---\n",
|
|
"name equals owner | 'Hello boss'\n",
|
|
"otherwise | 'Hello guest'\n",
|
|
"def greet(name, owner):\n",
|
|
" \"\"\",\n",
|
|
" 0.8,\n",
|
|
" ]\n",
|
|
"] "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "fb3d90fc-6932-4343-9c86-70ae94ca95aa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Running locally at: http://127.0.0.1:7861/\n",
|
|
"This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)\n",
|
|
"Running on External URL: https://34711.gradio.app\n",
|
|
"Interface loading below...\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"\n",
|
|
" <iframe\n",
|
|
" width=\"900\"\n",
|
|
" height=\"500\"\n",
|
|
" src=\"https://34711.gradio.app\"\n",
|
|
" frameborder=\"0\"\n",
|
|
" allowfullscreen\n",
|
|
" ></iframe>\n",
|
|
" "
|
|
],
|
|
"text/plain": [
|
|
"<IPython.lib.display.IFrame at 0x7f0bf02db670>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(<Flask 'gradio.networking'>,\n",
|
|
" 'http://127.0.0.1:7861/',\n",
|
|
" 'https://34711.gradio.app')"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"inputs = [\n",
|
|
" gr.inputs.Textbox(placeholder=\"Define a problem here...\", lines=7),\n",
|
|
" gr.inputs.Textbox(placeholder=\"Provide optional starter code...\", lines=3),\n",
|
|
" gr.inputs.Slider(0.5, 1.5, 0.1, default=0.8, label=\"Temperature\"),\n",
|
|
" gr.inputs.Slider(1,4,1,default=1, label=\"Beam size\")\n",
|
|
"]\n",
|
|
"\n",
|
|
"outputs = [\n",
|
|
" gr.outputs.HTML(label=\"Solution\")\n",
|
|
"]\n",
|
|
"\n",
|
|
"gr.Interface(\n",
|
|
" generate_solution, \n",
|
|
" inputs=inputs, \n",
|
|
" outputs=outputs,\n",
|
|
" title=\"Code Clippy: Problem Solver\",\n",
|
|
" examples=_EXAMPLES,\n",
|
|
").launch(share=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "737cd94c-5286-4832-9611-06e6f2a89357",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|