gpt-code-clippy/flax-gpt-neo-clm.ipynb
2021-07-05 21:24:02 +03:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uGYl4nCPKyZi"
},
"source": [
"# Fine-tuning GPT-Neo with 🤗 Transformers and **Flax/JAX** on TPU\n",
"\n",
"In this notebook, we will see how to fine-tune a pretrained [🤗 Transformers](https://github.com/huggingface/transformers) model, GPT-Neo, on TPU using [**Flax**](https://flax.readthedocs.io/en/latest/index.html). \n",
"\n",
"GPT-Neo is fine-tuned here with the causal language modeling (CLM) objective, i.e. predicting the next token.\n",
"\n",
"As can be seen on [this benchmark](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation), using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.\n",
"\n",
"[**Flax**](https://flax.readthedocs.io/en/latest/index.html) is a high-performance neural network library built on top of JAX (see below) and designed for flexibility. It aims to provide users with full control of their training code and is carefully designed to work well with JAX transformations such as `grad` and `pmap` (see the [Flax philosophy](https://flax.readthedocs.io/en/latest/philosophy.html)). For an introduction to Flax, see the [Flax Basics Colab](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) or the list of curated [Flax examples](https://flax.readthedocs.io/en/latest/examples.html).\n",
"\n",
"[**JAX**](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwDAzFXQMd46"
},
"source": [
"If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Tokenizers as well as [Flax](https://github.com/google/flax.git) and [Optax](https://github.com/deepmind/optax). Optax is a gradient processing and optimization library for JAX and the optimizer library\n",
"recommended by Flax."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "If_SYBvU5V6u"
},
"source": [
"If everything is set up correctly, calling `jax.local_devices()` below should return a list of 8 TPU devices (4 chips with 2 cores each)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
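"# `!export` only affects a throw-away subshell, so set the variable in this Python process instead.\n",
"# It must be set before `transformers` or `datasets` is imported to take effect.\n",
"import os\n",
"os.environ[\"HF_HOME\"] = \"/home/shared/.cache/hf/\""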
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3R5MP7PAbV7V",
"outputId": "9cf6b9a4-7b9c-4029-d938-dcf9a3ebb4c4"
},
"outputs": [
{
"data": {
"text/plain": [
"[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
" TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
" TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
" TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
" TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
" TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
" TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
" TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"jax.local_devices()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vehXZCipMa1V"
},
"source": [
"In this notebook, we will fine-tune an [autoregressive model](https://huggingface.co/transformers/model_summary.html#autoregressive-models) on the Python subset of CodeSearchNet (`code_search_net` on the 🤗 Hub), a corpus of functions collected from open-source GitHub repositories in six programming languages."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iz8HrV8JPHn0"
},
"source": [
"Let's first select the dataset and the programming language that our model should learn.\n",
"You can change the language by setting the corresponding configuration name in the following cell. Besides `python`, `code_search_net` provides the configurations `go`, `java`, `javascript`, `php` and `ruby`, as well as `all`.\n",
"\n",
"Beware that larger datasets might break this demonstration notebook 💥. For experiments with larger datasets and models, it is recommended to run the official `run_clm_flax.py` script offline, which can be found [here](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#masked-language-modeling).\n",
"\n",
"Here we select `python` 🐍."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "ii9XwLsmiY-E"
},
"outputs": [],
"source": [
"ds_name = \"code_search_net\"\n",
"language = \"python\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jVtv6T0oSjNq"
},
"source": [
"Next, we select the pretrained checkpoint to fine-tune.\n",
"The dropdown below offers the EleutherAI GPT-Neo checkpoints with 125M, 1.3B and 2.7B parameters, but essentially any auto-regressive model that is available on the [**🤗 hub**](https://huggingface.co/models?filter=masked-lm,jax) in JAX/Flax can be used. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from ipywidgets import Dropdown"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc69b75cca024772ab08e70c811dca94",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Dropdown(options=('EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B'), value='Ele…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ckpt_selector = Dropdown(options = [\"EleutherAI/gpt-neo-125M\", \"EleutherAI/gpt-neo-1.3B\", \"EleutherAI/gpt-neo-2.7B\"])\n",
"ckpt_selector"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "Sj1mJNJa6PPS"
},
"outputs": [],
"source": [
"model_ckpt = ckpt_selector.value"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j-tf_3Ch55_9"
},
"source": [
"## 1. Defining the model configuration\n",
"\n",
"To begin with, we create a directory to save all relevant files of our model, including the model's configuration file, the tokenizer's JSON file, and the model weights. The directory is named after the selected checkpoint, e.g. `~/gpt-neo-125M-code-clippy` for `EleutherAI/gpt-neo-125M`:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "1dwuSvQxeM8-"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"model_dir = Path.home()/f\"{model_ckpt.split('/')[1]}-code-clippy\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qGENnc6LeRFL"
},
"source": [
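"and create it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_dir.mkdir(parents=True, exist_ok=True)"
]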
},
{
"cell_type": "markdown",
"metadata": {
"id": "oWQD8IA9eAFY"
},
"source": [
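"Next, we could download the model configuration. This step is optional, because `FlaxAutoModelForCausalLM.from_pretrained` further below loads the configuration together with the pretrained weights, so the next two code cells are left commented out:"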
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "DO1SwHdi55en"
},
"outputs": [],
"source": [
"# from transformers import AutoConfig\n",
"\n",
"# config = AutoConfig.from_pretrained(model_ckpt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3exPFi-keYlT"
},
"source": [
"and save it to the model directory:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "Vip8WKEp6b6Y"
},
"outputs": [],
"source": [
"# config.save_pretrained(f\"{model_dir}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Loading the dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "kJKw0tqOcDu6"
},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cQXZ1p5XHtP"
},
"source": [
"Tokenizer and model files will be stored in `model_dir`. We can conveniently load the chosen dataset with the [**`load_dataset`**](https://huggingface.co/docs/datasets/package_reference/loading_methods.html?highlight=load_dataset#datasets.load_dataset) function, and the tokenizer matching the selected checkpoint with `AutoTokenizer.from_pretrained`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5oUW__q-4If7",
"outputId": "4e5f1bd9-b6c1-42fe-ea21-c00b1c4ff47a"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:datasets.builder:Reusing dataset code_search_net (/home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27)\n"
]
}
],
"source": [
"dataset = load_dataset(ds_name, language)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "OCs_CQFt4WK_"
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4hD8d1_P5huo"
},
"source": [
"## 3. Pre-processing the dataset\n",
"\n",
"The pretrained tokenizer can now be used to pre-process the raw source code. \n",
"GPT-Neo supports sequences of up to `2048` tokens (the limit mentioned in the tokenizer warnings below), while GPT2 was trained on sequences of up to `1024` tokens, see paper [here](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).\n",
"However, since the memory required by self-attention scales quadratically with the sequence length, we cap the maximum input length at `128` here (the commented-out value `512` is a larger alternative). The raw text data is pre-processed accordingly."
]
},
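{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the quadratic scaling concrete, here is a small back-of-the-envelope helper. It is only a sketch under stated assumptions: `attention_scores_bytes` is a hypothetical helper that assumes a standard self-attention layer materializing one `seq_len x seq_len` score matrix per head, uses 12 heads purely as an example value, and ignores all other activations and the parameters. Going from 128 to 512 tokens multiplies this term by 16."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper for illustration only; num_heads=12 is an assumed example value.\n",
"def attention_scores_bytes(seq_len, num_heads=12, bytes_per_value=2):\n",
"    # One seq_len x seq_len score matrix per head, per layer and per sequence;\n",
"    # bytes_per_value=2 corresponds to bfloat16.\n",
"    return num_heads * seq_len * seq_len * bytes_per_value\n",
"\n",
"for seq_len in (128, 512, 2048):\n",
"    mib = attention_scores_bytes(seq_len) / 2**20\n",
"    print(f'{seq_len:>5} tokens: {mib:8.2f} MiB per layer and sequence')"
]
},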
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "uDhqWoF-MAGv"
},
"outputs": [],
"source": [
"max_seq_length = 128  # or 512 if TPU memory allows"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20000, 2000)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dataset['train']), len(dataset['validation'])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "aoXFHjEtwXWt"
},
"outputs": [],
"source": [
"# comment out these two lines to run on the full dataset\n",
"dataset[\"train\"] = dataset[\"train\"].select(range(20000))\n",
"dataset[\"validation\"] = dataset[\"validation\"].select(range(2000))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hYmElz46k7_E"
},
"source": [
"Next, we define a tokenization function that applies the pretrained GPT-Neo tokenizer (a byte-level BPE tokenizer) to the `func_code_string` column, which contains the source code of each function:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "wcpWIxX8dIAO"
},
"outputs": [],
"source": [
"def tokenize_function(examples):\n",
" return tokenizer(examples[\"func_code_string\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lco7GkZ8nF-a"
},
"source": [
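"and apply the tokenization function to every sample via the convenient `map(...)` function of Datasets. To speed up the computation, we process larger batches at once via `batched=True`, split the computation over `num_proc=8` processes, and drop the original columns with `remove_columns` so that only the tokenizer outputs remain.\n",
"\n",
"**Note**: Running this command on the whole dataset might take up to 10 minutes ☕. The warnings about sequences longer than `2048` tokens can be ignored here, because the next step splits everything into blocks of `max_seq_length` tokens."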
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "h6cjpFO2dTYC",
"outputId": "5b57b7aa-79f5-4780-95be-eddc79760f3b"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (2090 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2219 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (4716 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2795 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2736 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2943 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (3776 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (3817 > 2048). Running this sequence through the model will result in indexing errors\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-7b66142abafedf08.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-a010a7cca7e5eda5.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-e31ef47e93e1b72f.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-a0ccaa62536849b2.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-78eeda3c362d3d04.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-ef7a0c77b4162772.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-94163ac3c687abcb.arrow\n",
"WARNING:datasets.arrow_dataset:Loading cached processed dataset at /home/arto/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27/cache-a707561ad73af264.arrow\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (5852 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (3702 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2478 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (5691 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (4533 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2182 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2240 > 2048). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2509 > 2048). Running this sequence through the model will result in indexing errors\n"
]
}
],
"source": [
"tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=8, remove_columns=dataset[\"train\"].column_names)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6_E0jsY9onEf"
},
"source": [
"The model can process the training data most efficiently if all data samples are of the same length. We concatenate the tokenized samples and split them evenly into blocks of `max_seq_length=128` tokens each; leftover tokens that do not fill a whole block are dropped. This way, no computation is wasted on padding tokens.\n",
"Causal language modeling simply consists of predicting the next token, which means that the labels are essentially the inputs shifted by one position to the left. Thus, we copy the `input_ids` to `labels` here; the shift itself is applied later in the loss function.\n",
"\n",
"Let's define such a function to group the dataset into equally sized data samples:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "HO_neGynddat"
},
"outputs": [],
"source": [
"def group_texts(examples):\n",
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
" total_length = (total_length // max_seq_length) * max_seq_length\n",
" result = {\n",
" k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]\n",
" for k, t in concatenated_examples.items()\n",
" }\n",
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
" return result"
]
},
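{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `group_texts` does, the following sketch runs the same logic on a toy batch. It is an illustration with two assumptions: `group_texts_demo` is a hypothetical helper that takes the block size as an argument instead of reading the global `max_seq_length`, and it concatenates with `itertools.chain`, which runs in linear time, whereas `sum(lists, [])` copies the growing list at every step and is quadratic in the total length. On large datasets, swapping in `chain` may speed up this step noticeably."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from itertools import chain\n",
"\n",
"# Hypothetical helper for illustration only; mirrors group_texts above.\n",
"def group_texts_demo(examples, block_size):\n",
"    # Concatenate every column (input_ids, attention_mask, ...) into one long list.\n",
"    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}\n",
"    total_length = len(next(iter(concatenated.values())))\n",
"    # Drop the tail that does not fill a complete block.\n",
"    total_length = (total_length // block_size) * block_size\n",
"    result = {\n",
"        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
"        for k, t in concatenated.items()\n",
"    }\n",
"    # Labels are a copy of the inputs; the shift by one happens in the loss.\n",
"    result['labels'] = result['input_ids'].copy()\n",
"    return result\n",
"\n",
"toy = {'input_ids': [[1, 2, 3], [4, 5], [6, 7, 8, 9]], 'attention_mask': [[1, 1, 1], [1, 1], [1, 1, 1, 1]]}\n",
"print(group_texts_demo(toy, block_size=4)['input_ids'])  # token 9 is dropped"
]
},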
{
"cell_type": "markdown",
"metadata": {
"id": "CR46Vvpwr6e5"
},
"source": [
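"We pass `group_texts` to the `map(...)` function and set `batched=True`, so that the function receives a large batch of samples at once (1,000 by default). Concatenation and splitting happen within each batch, so at most `max_seq_length - 1` tokens are dropped per batch. \n",
"\n",
"**Note**: Running this function on the whole dataset might take up to 50 minutes 🕒."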
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UmzNAUVediDa",
"outputId": "00f5fd1c-cb16-4539-f24e-a78d7164ff85"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"https://symbolize.stripped_domain/r/?trace=56a592,7f3f9d4bc20f,955ddf,3a0fc0f&map= \n",
"*** SIGTERM received by PID 361924 (TID 361924) on cpu 0 from PID 359912; stack trace: ***\n",
"PC: @ 0x56a592 (unknown) _PyEval_EvalFrameDefault\n",
" @ 0x7f3f8bb69800 976 (unknown)\n",
" @ 0x7f3f9d4bc210 (unknown) (unknown)\n",
" @ 0x955de0 (unknown) (unknown)\n",
" @ 0x3a0fc10 (unknown) (unknown)\n",
"https://symbolize.stripped_domain/r/?trace=56a592,7f3f8bb697ff,7f3f9d4bc20f,955ddf,3a0fc0f&map=2a762cd764e70bc90ae4c7f9747c08d7:7f3f7ec27000-7f3f8bea8280 \n",
"E0705 13:31:05.285333 361924 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.\n",
"E0705 13:31:05.301524 361924 process_state.cc:771] RAW: Raising signal 15 with default behavior\n"
]
}
],
"source": [
"tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=8)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"], axis=0)\n",
" d_z = np.diagonal(D)[2:][0]\n",
"\n",
" cond1 = np.all(d_xy < allowed_xy_movement)\n",
" cond2 = np.all([d_z[i] < allowed_z_movement[i]\n",
" for i in range(len(a))])\n",
"\n",
" \n"
]
}
],
"source": [
"import random\n",
"\n",
"idx = random.randint(0, len(tokenized_datasets['train']) - 1)\n",
"\n",
"example = tokenized_datasets['train'][idx]\n",
"print(tokenizer.decode(example['input_ids']))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jid2JqXOsVfR"
},
"source": [
"Awesome, the data is now fully pre-processed and ready to be used for training 😎."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZRvfr609LzWu"
},
"source": [
"## 4. Fine-tuning the model\n",
"\n",
"Now we will see how the power of Google's Tensor Processing Units (TPUs) can be leveraged with Flax/JAX for the compute-intensive training of language models.\n",
"\n",
"We need to import `jax`, `flax`, `optax`, `numpy` to define our training loop. Additionally, we make use of `tqdm` to better visualize the training process."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "5qOhue4Xm1TO"
},
"outputs": [],
"source": [
"import jax\n",
"import optax\n",
"import flax\n",
"import jax.numpy as jnp\n",
"import math\n",
"\n",
"from flax.training import train_state\n",
"from flax.training.common_utils import get_metrics, onehot, shard\n",
"\n",
"import numpy as np\n",
"\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_MGleTRG6Vor"
},
"source": [
"At first, we define all relevant hyper-parameters for fine-tuning in this notebook:\n",
"\n",
"- Each TPU core processes a batch size of `1`\n",
"- The model is trained for `1` epoch\n",
"- The learning rate is linearly warmed up to `3e-4` over the first 10% of the training steps and then linearly decayed to 0 (this schedule applies when training with AdamW; the Adafactor optimizer used below keeps `3e-4` constant)\n",
"- To reproduce the training run, a random seed is set to `0`.\n",
"\n",
"We can deduce the total batch size over all devices as well as the total number of training steps accordingly."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "y8lsJQy8liud"
},
"outputs": [],
"source": [
"per_device_batch_size = 1\n",
"num_train_epochs = 1\n",
"training_seed = 0\n",
"learning_rate = 3e-4\n",
"\n",
"total_batch_size = per_device_batch_size * jax.device_count()\n",
"num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_train_epochs"
]
},
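{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of this arithmetic, the hypothetical helper `training_steps` below repeats it in plain Python, together with the 10% warmup computed further down. The sample count `76440` is an assumption chosen to be consistent with the `9555` steps per epoch printed later (`76440 // 8 = 9555`); substitute `len(tokenized_datasets['train'])` for a real run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper for illustration only; 76440 below is an assumed sample count.\n",
"def training_steps(num_samples, per_device_batch_size, num_devices, num_epochs, warmup_fraction=0.1):\n",
"    # The global batch is spread over all devices.\n",
"    total_batch_size = per_device_batch_size * num_devices\n",
"    # Integer division drops the last incomplete batch of each epoch.\n",
"    num_train_steps = num_samples // total_batch_size * num_epochs\n",
"    warmup_steps = int(warmup_fraction * num_train_steps)\n",
"    return total_batch_size, num_train_steps, warmup_steps\n",
"\n",
"print(training_steps(76440, per_device_batch_size=1, num_devices=8, num_epochs=1))  # (8, 9555, 955)"
]
},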
{
"cell_type": "markdown",
"metadata": {
"id": "FB9bRDBq5j3r"
},
"source": [
"In the [official GPT2 paper](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) a batch size of 512 is used.\n",
"\n",
"Here, we use a total batch size of `8 * 1 = 8` due to the TPU memory constraints of this notebook and the size of the GPT-Neo checkpoints. With a smaller checkpoint or shorter sequences, considerably larger per-device batch sizes may fit on a TPUv3-8."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i0Tylp115u1r"
},
"source": [
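"Now we load the pretrained weights of the selected GPT-Neo checkpoint. To save memory and improve speed, we run the model's computations in `bfloat16` by setting `dtype=jnp.dtype(\"bfloat16\")`. Note that in 🤗 Transformers, `dtype` controls the dtype of the computation; the loaded parameters themselves stay in `float32`."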
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "aVr9TCzfacLN"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"tcmalloc: large alloc 5262319616 bytes == 0x90a30000 @ 0x7f3f9d686680 0x7f3f9d6a7824 0x5f7b11 0x648631 0x5c38e6 0x4f30e6 0x64ee88 0x505653 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332 0x568d9a 0x50b868 0x56bc9b 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf\n"
]
}
],
"source": [
"from transformers import FlaxAutoModelForCausalLM\n",
"\n",
"model = FlaxAutoModelForCausalLM.from_pretrained(model_ckpt, dtype=jnp.dtype(\"bfloat16\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sMS_QkT76Lgk"
},
"source": [
"Next, we define the learning rate schedule. A simple and effective learning rate schedule is the linear decay with warmup (click [here](https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup) for more information). Here, the learning rate increases linearly from 0 to `learning_rate` during the first 10% of the training steps and then decays linearly back to 0, so the schedule is fully defined by the number of warmup steps, the number of training steps and the peak learning rate. The commented-out cell below shows the simpler variant without warmup.\n",
"\n",
"It is recommended to use the [**optax**](https://github.com/deepmind/optax) library for training utilities, *e.g.* learning rate schedules and optimizers.\n",
"\n",
"The `create_learning_rate_fn` helper used here follows the [official Flax CLM pre-training script](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_clm_flax.py)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kfBkuV1ck4rq"
},
"outputs": [],
"source": [
"# linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9555"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"steps_per_epoch = len(tokenized_datasets[\"train\"]) // total_batch_size\n",
"steps_per_epoch"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"955"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"warmup_steps = int(0.1 * num_train_steps)\n",
"warmup_steps"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"total_batch_size"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"def create_learning_rate_fn(\n",
" train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n",
") -> Callable[[int], jnp.array]:\n",
" \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n",
" steps_per_epoch = train_ds_size // train_batch_size\n",
" num_train_steps = steps_per_epoch * num_train_epochs\n",
" warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n",
" decay_fn = optax.linear_schedule(\n",
" init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n",
" )\n",
" schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n",
" return schedule_fn"
]
},
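{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the values this schedule produces, here is a plain-Python sketch of the same warmup-then-decay shape. It is an illustration, not optax itself: `warmup_then_linear_decay` is a hypothetical helper that assumes the semantics of `optax.linear_schedule` (linear interpolation, clipped at the end value) and `optax.join_schedules` (the second schedule starts counting at the boundary). The last line shows why the schedule must be built from the size of the *grouped* training set: a schedule sized for the 20,000 raw functions (2,500 steps at a batch size of 8) has already decayed to 0 long before the 9,555 training steps are done."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper for illustration only; mirrors create_learning_rate_fn above.\n",
"def warmup_then_linear_decay(step, peak_lr, num_warmup_steps, num_train_steps):\n",
"    if step < num_warmup_steps:\n",
"        # linear warmup from 0 to peak_lr\n",
"        return peak_lr * step / num_warmup_steps\n",
"    # linear decay from peak_lr to 0, clipped at 0 after the last step\n",
"    decay_steps = num_train_steps - num_warmup_steps\n",
"    progress = min((step - num_warmup_steps) / decay_steps, 1.0)\n",
"    return peak_lr * (1.0 - progress)\n",
"\n",
"for step in (0, 477, 955, 5255, 9555):\n",
"    print(step, warmup_then_linear_decay(step, 3e-4, 955, 9555))\n",
"\n",
"# schedule sized for 20000 raw rows / batch size 8 = 2500 steps, warmup still 955:\n",
"print(warmup_then_linear_decay(3000, 3e-4, 955, 2500))  # 0.0"
]
},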
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
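"# use the size of the grouped training set, which is what the training loop iterates over\n",
"linear_decay_lr_schedule_fn = create_learning_rate_fn(\n",
"    len(tokenized_datasets['train']),\n",
"    total_batch_size,\n",
"    num_train_epochs,\n",
"    warmup_steps,\n",
"    learning_rate,\n",
")"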
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2p0yNxeU79F2"
},
"source": [
"We will be using the standard Adam optimizer with weight decay, called AdamW (Adam + weight decay). \n",
"\n",
"AdamW can easily be imported from [optax](https://github.com/deepmind/optax) and is created from the just defined learning rate schedule as well as a couple of other hyper-parameters (*beta1*, *beta2*, *epsilon*, *weight decay*) that are hard-coded in this notebook.\n",
"\n",
"For more information on AdamW (Adam + weight decay), one can take a look at [this](https://www.fast.ai/2018/07/02/adam-weight-decay/) blog post."
]
},
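{
"cell_type": "markdown",
"metadata": {},
"source": [
"The update that AdamW applies can be written out in a few lines. The NumPy sketch below is illustrative only: `adamw_step` is a hypothetical helper that follows the structure of `optax.adamw` (bias-corrected Adam moments, plus weight decay that is added to the update rather than to the gradient, with the sum scaled by the learning rate), using the hyper-parameters from the next cell. Because the decay is decoupled from the gradient, it is not rescaled by Adam's per-parameter normalization, which is the point made in the blog post linked above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helper for illustration only; optax.adamw is what the notebook uses.\n",
"def adamw_step(params, grads, m, v, step, lr, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01):\n",
"    # step counts from 1; m and v are the running first and second moments\n",
"    m = b1 * m + (1 - b1) * grads\n",
"    v = b2 * v + (1 - b2) * grads**2\n",
"    m_hat = m / (1 - b1**step)  # bias correction\n",
"    v_hat = v / (1 - b2**step)\n",
"    # decoupled weight decay: added to the Adam direction, not to the gradient\n",
"    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * params\n",
"    return params - lr * update, m, v\n",
"\n",
"p = np.array([1.0, 0.0])\n",
"g = np.array([0.0, 1.0])\n",
"p, m, v = adamw_step(p, g, np.zeros(2), np.zeros(2), step=1, lr=3e-4)\n",
"print(p)  # first entry only decays by lr * weight_decay, second moves by about -lr"
]
},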
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"id": "xRtpv_iamZd2"
},
"outputs": [],
"source": [
"adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
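"# Adafactor is a memory-efficient alternative to AdamW: it stores factored second-moment\n",
"# estimates for large weight matrices. Here it uses a constant learning rate of 3e-4.\n",
"adafactor = optax.adafactor(learning_rate=3e-4)"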
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6g_fEbV-72Hc"
},
"source": [
"Next, we will create the *training state*, which bundles the model parameters with the optimizer and is responsible for updating the model's parameters during training.\n",
"\n",
"Most JAX transformations (notably [jax.jit](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)) require functions that are transformed to have no side effects. This is because any such side-effects will only be executed once when the Python version of the function is run during compilation (see [Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html)). As a consequence, Flax models (which can be transformed by JAX transformations) are **immutable**, and the state of the model (i.e., its weight parameters) is stored *outside* of the model instance.\n",
"\n",
"Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.\n",
"\n",
"Flax provides a convenience class [`flax.training.train_state.TrainState`](https://github.com/google/flax/blob/9da95cdd12591f42d2cd4c17089861bff7e43cc5/flax/training/train_state.py#L22), which stores the training step, the model's forward function (`apply_fn`), the parameters, the optax optimizer (`tx`) and its state, and exposes an `apply_gradients` function to update the model's weight parameters.\n",
"\n",
"Alright, let's create our *training state* with `TrainState.create`, passing the model's forward pass as the `apply_fn`, the pretrained `params`, and the Adafactor optimizer defined above as `tx`. To train with AdamW and the warmup/decay schedule instead, pass `tx=adamw`."
]
},
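{
"cell_type": "markdown",
"metadata": {},
"source": [
"The functional-update pattern described above can be illustrated without JAX. The sketch below is a hypothetical, heavily simplified stand-in for `TrainState` (plain SGD on a dict of floats, no optimizer state); it only shows that `apply_gradients` returns a *new* state object and leaves the old one untouched, which keeps the training step free of side effects and therefore safe to transform with `jax.jit` or `jax.pmap`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass, replace\n",
"from typing import Dict\n",
"\n",
"# Hypothetical, simplified stand-in for flax's TrainState (illustration only).\n",
"@dataclass(frozen=True)\n",
"class TinyTrainState:\n",
"    step: int\n",
"    params: Dict[str, float]\n",
"    learning_rate: float\n",
"\n",
"    def apply_gradients(self, grads):\n",
"        # plain SGD update, returned as a new state; self is never mutated\n",
"        new_params = {k: p - self.learning_rate * grads[k] for k, p in self.params.items()}\n",
"        return replace(self, step=self.step + 1, params=new_params)\n",
"\n",
"s0 = TinyTrainState(step=0, params={'w': 1.0}, learning_rate=0.1)\n",
"s1 = s0.apply_gradients({'w': 2.0})\n",
"print(s0.params, s1.params, s1.step)"
]
},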
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"id": "JHYfR67AoKRc"
},
"outputs": [],
"source": [
"state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adafactor)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xiYCejDd81TX"
},
"source": [
"Next, let's implement a data loader for both training and evaluation.\n",
"The data loader can be defined as a [Python generator](https://wiki.python.org/moin/Generators) that yields a batch of model inputs every time it is iterated.\n",
"\n",
"First, the sample indices are (optionally) randomly permuted. \n",
"Then, at every step the next batch of indices is used to extract the samples, which are converted to JAX arrays and sharded over all local TPU devices."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"id": "Aos9GltTb3Ve"
},
"outputs": [],
"source": [
"def data_loader(rng, dataset, batch_size, shuffle=False):\n",
"    steps_per_epoch = len(dataset) // batch_size\n",
"\n",
"    if shuffle:\n",
"        batch_idx = jax.random.permutation(rng, len(dataset))\n",
"    else:\n",
"        batch_idx = jnp.arange(len(dataset))\n",
"\n",
"    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.\n",
"    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
"\n",
"    for idx in batch_idx:\n",
"        batch = dataset[idx]\n",
"        batch = {k: jnp.array(v) for k, v in batch.items()}\n",
"\n",
"        # split the leading batch axis across devices: (n_devices, batch_size // n_devices, ...)\n",
"        batch = shard(batch)\n",
"\n",
"        yield batch"
]
},
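{
"cell_type": "markdown",
"metadata": {},
"source": [
"The index bookkeeping and the sharding in `data_loader` can be checked with plain NumPy. This sketch makes two assumptions: `batch_indices` and `shard_like` are hypothetical helpers, and `shard_like` reproduces the reshape that `flax.training.common_utils.shard` applies (leading axis split into `(num_devices, -1)`), with the device count passed in explicitly instead of being queried from JAX."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helpers for illustration only; they mirror data_loader and shard.\n",
"def batch_indices(num_samples, batch_size, seed=None):\n",
"    steps_per_epoch = num_samples // batch_size\n",
"    if seed is None:\n",
"        idx = np.arange(num_samples)\n",
"    else:\n",
"        idx = np.random.default_rng(seed).permutation(num_samples)\n",
"    # skip the incomplete last batch, then one row of indices per step\n",
"    return idx[: steps_per_epoch * batch_size].reshape(steps_per_epoch, batch_size)\n",
"\n",
"def shard_like(x, num_devices):\n",
"    # (batch, ...) -> (num_devices, batch // num_devices, ...)\n",
"    return x.reshape((num_devices, -1) + x.shape[1:])\n",
"\n",
"print(batch_indices(10, 4))  # samples 8 and 9 are skipped\n",
"print(shard_like(np.zeros((8, 128)), num_devices=8).shape)  # one 128-token sequence per device"
]
},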
{
"cell_type": "markdown",
"metadata": {
"id": "L7uoTXDLUzb-"
},
"source": [
"At each training epoch, the dataset should be shuffled, and superfluous samples that make the dataset not evenly divisible by the batch size are thrown away. Instead of passing the dataset around, we prepare the indices of the data samples to be used in each training epoch. \n",
"The indices for the training dataset are additionally randomly shuffled before each epoch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU6idLb29xYu"
},
"source": [
"During fine-tuning, we want to update the model parameters and evaluate the performance after each epoch. \n",
"\n",
"Let's write the functions `train_step` and `eval_step` accordingly. During training the weight parameters should be updated as follows:\n",
"\n",
"1. Define a loss function `loss_fn` that first runs a forward pass of the model given data input. Remember that Flax models are immutable, and we explicitly pass it the state (in this case the model parameters and the RNG). `loss_fn` returns a scalar loss, the mean cross-entropy computed with `optax.softmax_cross_entropy`, between the model output and the input targets.\n",
"2. Differentiate this loss function using [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#evaluate-a-function-and-its-gradient-using-value-and-grad). This is a JAX transformation called [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), which computes the gradient of `loss_fn` with respect to its input (i.e., the parameters of the model), and returns the value and the gradient in a pair `(loss, gradients)`.\n",
"3. Compute the mean gradient over all devices using the collective operation [lax.pmean](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html). As we will see below, each device runs `train_step` on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.\n",
"4. Use `state.apply_gradients`, which applies the gradients to the weights.\n",
"\n",
"Below, you can see how each of the described steps above is put into practice.\n",
"\n",
"Also note that the `labels` are shifted one to the left and the last token of the `logits` is cut. This way, the model learns to predict the **next** token as defined in causal language modeling."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"id": "GjKzb0zJd-aH"
},
"outputs": [],
"source": [
"def train_step(state, batch, dropout_rng):\n",
"    # Use one key for dropout in this step and return a fresh key for the next step.\n",
"    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)\n",
"    # Pop the labels once, outside of the differentiated function.\n",
"    labels = batch.pop(\"labels\")\n",
"\n",
"    def loss_fn(params):\n",
"        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n",
"        # Position t predicts token t + 1: drop the last logit and the first label.\n",
"        loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()\n",
"        return loss\n",
"\n",
"    grad_fn = jax.value_and_grad(loss_fn)\n",
"    loss, grad = grad_fn(state.params)\n",
"    # Average the gradients across devices so every replica applies the same update.\n",
"    grad = jax.lax.pmean(grad, \"batch\")\n",
"    new_state = state.apply_gradients(grads=grad)\n",
"\n",
"    metrics = jax.lax.pmean(\n",
"        {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}, axis_name=\"batch\"\n",
"    )\n",
"\n",
"    return new_state, metrics, new_dropout_rng"
]
},
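{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the label shift concrete, the cell below applies the same slicing and loss as `train_step` to a hypothetical batch holding one sequence of 4 tokens over a vocabulary of 5. With all-zero (uniform) logits, each of the 3 remaining predictions assigns probability 1/5 to the correct token, so the loss should equal log(5), roughly 1.609. The sizes, values, and `toy_` names are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import optax\n",
"from flax.training.common_utils import onehot\n",
"\n",
"# Hypothetical toy batch: 1 sequence of 4 token ids, vocabulary size 5.\n",
"toy_labels = jnp.array([[2, 0, 3, 1]])\n",
"toy_logits = jnp.zeros((1, 4, 5))  # uniform predictions at every position\n",
"\n",
"# Position t predicts token t + 1: drop the last logit and the first label.\n",
"toy_shifted_logits = toy_logits[..., :-1, :]  # shape (1, 3, 5)\n",
"toy_shifted_labels = toy_labels[..., 1:]  # shape (1, 3), tokens [0, 3, 1]\n",
"\n",
"toy_loss = optax.softmax_cross_entropy(\n",
"    toy_shifted_logits, onehot(toy_shifted_labels, toy_logits.shape[-1])\n",
").mean()\n",
"print(toy_shifted_logits.shape, toy_shifted_labels.shape)  # (1, 3, 5) (1, 3)\n",
"print(float(toy_loss))  # approximately log(5)"
]
},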
{
"cell_type": "markdown",
"metadata": {
"id": "nCPedI-B-FMQ"
},
"source": [
"Now, we want to do parallelized training over all TPU devices. To do so, we use [`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html?highlight=pmap#parallelization-pmap). This compiles the function once and runs the same program on each device (it is an [SPMD program](https://en.wikipedia.org/wiki/SPMD)). When this pmapped function is called, every input (`state`, `batch`, `dropout_rng`) must have a leading axis whose size equals the number of devices, and `jax.pmap` maps over that first axis: the state is replicated so that each device holds a full copy, the batch is sharded so that each device receives a different slice, and each device gets its own dropout RNG. The axis name `\"batch\"` passed to `jax.pmap` is the one referred to by `jax.lax.pmean` inside `train_step`."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"id": "w3k1Lqerpw5k"
},
"outputs": [],
"source": [
"parallel_train_step = jax.pmap(train_step, \"batch\")"
]
},
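{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal, self-contained sketch of the pattern used by `parallel_train_step`: the input carries a leading device axis, `jax.pmap` runs the function once per device, and `jax.lax.pmean` averages a value across the axis named `batch`. The toy function `toy_step` and its data are hypothetical. On a TPU v3-8 the leading axis has size 8, while on a CPU-only runtime it has size 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"toy_n_devices = jax.local_device_count()\n",
"\n",
"def toy_step(x):\n",
"    # Each device sees only its own row of x.\n",
"    local_mean = x.mean()\n",
"    # Average the per-device means across all devices on the 'batch' axis.\n",
"    return jax.lax.pmean(local_mean, axis_name='batch')\n",
"\n",
"# Leading axis = number of local devices, like the sharded training batches.\n",
"toy_x = jnp.arange(toy_n_devices * 4, dtype=jnp.float32).reshape((toy_n_devices, 4))\n",
"toy_out = jax.pmap(toy_step, axis_name='batch')(toy_x)\n",
"print(toy_out)  # one identical global mean per device"
]
},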
{
"cell_type": "markdown",
"metadata": {
"id": "0DWFAZM6A8uf"
},
"source": [
"Similarly, we can now define the evaluation step. This function is much simpler, as we don't need to compute any gradients. To better monitor the performance improvement during training, the next-token loss and the corresponding perplexity (the exponential of the loss) are computed and stored in a `metrics` dictionary during evaluation."
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"id": "EGEv7dyfpW4p"
},
"outputs": [],
"source": [
"def eval_step(params, batch):\n",
" labels = batch.pop(\"labels\")\n",
"\n",
" logits = model(**batch, params=params, train=False)[0]\n",
"\n",
" loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()\n",
"\n",
" # summarize metrics\n",
" metrics = {\"loss\": loss, \"perplexity\": jnp.exp(loss)}\n",
" metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n",
" return metrics"
]
},
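{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail worth keeping in mind: `eval_step` averages `exp(loss)` across devices, and the training loop averages the result again over evaluation batches. Because the exponential is convex, the mean of the per-device perplexities is at least as large as the perplexity of the mean loss (Jensen's inequality), so the reported perplexity may slightly overestimate `exp(mean loss)`. The cell below illustrates this with hypothetical per-device losses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical per-device evaluation losses for a single step.\n",
"toy_device_losses = jnp.array([2.0, 2.5, 3.0, 3.5])\n",
"\n",
"toy_mean_of_ppl = jnp.exp(toy_device_losses).mean()  # what averaging 'perplexity' reports\n",
"toy_ppl_of_mean = jnp.exp(toy_device_losses.mean())  # exp of the averaged loss\n",
"print(float(toy_mean_of_ppl), float(toy_ppl_of_mean))  # the first value is larger"
]
},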
{
"cell_type": "markdown",
"metadata": {
"id": "guaYWTvFA_66"
},
"source": [
"Similarly, we also apply `jax.pmap` to the evaluation step."
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"id": "0B8U2r2RpzjV"
},
"outputs": [],
"source": [
"parallel_eval_step = jax.pmap(eval_step, \"batch\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DLaM60PCY8Ka"
},
"source": [
"Next, we replicate the training state (the weight parameters together with the optimizer state) on each device, so that we can pass it to our pmapped functions."
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kncZTfALp3PG",
"outputId": "3ce8ee5a-7bda-4ba9-8774-c52a363a98f5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/arto/jenv/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.\n",
" warnings.warn(\n",
"/home/arto/jenv/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.\n",
" warnings.warn(\n"
]
}
],
"source": [
"state = flax.jax_utils.replicate(state)"
]
},
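{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what replication does to a pytree, the cell below replicates a hypothetical toy parameter dictionary with `flax.jax_utils.replicate` and takes it back with `flax.jax_utils.unreplicate`, which returns the copy held by the first device (useful, for example, before saving weights). Every leaf gains a leading axis of size `jax.local_device_count()`. The `toy_` names are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import jax_utils\n",
"\n",
"# Hypothetical toy parameters; the real `state` is replicated the same way.\n",
"toy_params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}\n",
"\n",
"toy_replicated = jax_utils.replicate(toy_params)\n",
"# Each leaf now has a leading device axis.\n",
"print(jax.tree_util.tree_map(lambda x: x.shape, toy_replicated))\n",
"\n",
"# Take back the single copy stored on the first device.\n",
"toy_single = jax_utils.unreplicate(toy_replicated)\n",
"print(toy_single['w'].shape)  # (3, 2)"
]
},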
{
"cell_type": "markdown",
"metadata": {
"id": "i2xg8oI-ZJ3P"
},
"source": [
"We can almost start training! As a final preparation step, we generate a seeded [**PRNGKey**](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax-random-prngkey) that serves as the random seed for the dropout layers and for dataset shuffling.\n",
"\n",
"Just as we replicated the state on all 8 TPU devices, we also need one `PRNGKey` per device, which is why we split the initial `rng` key into `jax.local_device_count()` random seeds (8 on a TPU v3-8)."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"id": "idu3E9ubqZH3"
},
"outputs": [],
"source": [
"rng = jax.random.PRNGKey(training_seed)\n",
"dropout_rngs = jax.random.split(rng, jax.local_device_count())"
]
},
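{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below sketches how the per-device keys behave, using a hypothetical seed and `toy_` names so that the real `rng` and `dropout_rngs` are left untouched. With the default raw `uint32` keys, splitting yields an array of shape `(num_devices, 2)`, and each device then derives a fresh key per step, mirroring the split at the top of `train_step`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"toy_n_devices = jax.local_device_count()\n",
"toy_rng = jax.random.PRNGKey(0)  # hypothetical seed\n",
"toy_dropout_rngs = jax.random.split(toy_rng, toy_n_devices)\n",
"print(toy_dropout_rngs.shape)  # (num_devices, 2) for raw uint32 keys\n",
"\n",
"def toy_next_key(key):\n",
"    # Same pattern as in train_step: keep one key for this step, return the other.\n",
"    _, new_key = jax.random.split(key)\n",
"    return new_key\n",
"\n",
"toy_next_rngs = jax.pmap(toy_next_key)(toy_dropout_rngs)\n",
"print(toy_next_rngs.shape)  # same shape, new keys"
]
},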
{
"cell_type": "markdown",
"metadata": {
"id": "bKuMWHicbede"
},
"source": [
"Now, we are all set to finally start training! \n",
"Let's put all the pieces together and write the training loop. \n",
"\n",
"We start each epoch by splitting off a new random seed, `input_rng`, which is used to shuffle the training dataset. The per-device dropout RNGs are carried across steps and refreshed by `parallel_train_step` itself. \n",
"\n",
"Next, we create the training data loader, which generates the shuffled training indices.\n",
"In the first nested loop (the training loop), the loader yields batches that are already sharded across all 8 TPU devices, and we run the training step on each of them. \n",
"\n",
"Analogously, in the second nested loop (the evaluation loop), the evaluation batches are sharded and the evaluation step is run.\n",
"\n",
"**Note**: It might seem that the training loop cell below \"hangs\" when executed for the first time. This is because JAX traces and compiles the code the very first time it is run. After the first training step, you should notice that execution is much faster."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_loader = data_loader(rng, tokenized_datasets[\"validation\"], total_batch_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(eval_loader))\n",
"batch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch['labels'].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# labels = batch.pop('labels')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_metric = parallel_eval_step(state.params, batch)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_metric"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 427,
"referenced_widgets": [
"8c156c360afb4ddc962b87577e093cc4",
"9808f288735b4f5d9eea7377d4603ccc",
"68681edff0db43559012b6653571a289",
"125e8d57908f4679824c96d5247f7119",
"8134ef53fa8044c99b537386e88dbfc9",
"9b35ba10c85048ce8044b87ab9948611",
"e246db4008b14a23b55d2820be3b1ed8",
"b08557f6936e4af9a23f659649979c8d",
"dac789bcf81e4aac86717438ed6e8bbe",
"a1452333389441eab2c8016cd019d2f5",
"799b382dd6c74d528cf77813f82a7335",
"51d0dfb67bbd47ff88964bb446a91967",
"940ce122551e42509e8cec0659e47c77",
"a4234dc426d3459a93263c8b5fe5a34b",
"3549bf51a354408b9a39cb4532957617",
"7f05c774880548d2b73e69fb847c0a2e",
"71fa85dc266149e28685b1f5252185a5",
"9cb943192a154dbc8c7961f128449bbd",
"9de6ea64b67c443fbce1fd6aef4a7409",
"9658cc6977bb4fdf8a081aa99c300872",
"07b26680343f42fc805a3b52b398a5f2",
"97216855bb7f483882afae68bf78a51d",
"a1b9aa80b73748f18e3938df073b3d22",
"ce70b524ffab4052912d47b8adee6748",
"3c5df3c942ca4c768926b431064be351",
"f2e66aebdb4d4812a1f61c1ff990390a",
"269cc543a69e4ba3bae7f0bef67f1e25",
"8cb8398bb3c94284ac2ff42c7b3bc7c5",
"0b898451f79c4b6eb1678a8f366a5cbf",
"7a5426d6898e4ba2b5e3c48a65096232",
"2cf7d4c7dc6f413eb8cc9460e5de8d93",
"f7851bd9702949ee8e6278ad69c9d981",
"ab08bd64073d4d488f9c41e8595bac08",
"5e1805a3c7f24c3f9f15e86d15462d44",
"fc2e1a72620742499457ab2e31da12a9",
"b3b4d96626044375bad11ba48cc25ee3",
"c55e4e5f6d0a42b4ac25fb6d68a8e842",
"c7cca560cd4340bb986297655f513746",
"cd4892863d524637b8e567e2beed23bb",
"c85760cb6ccb4020be999357f2501bf9",
"7d0775fcf6a8497db20326dc94bd2739",
"5215d791d9144c089b6996336c3085cd",
"839daa7d5cdc4669942f1b5ef5bf4903",
"4bc03fce15a14355b3429c138ea1eadd",
"d7cddd2cafe84d17829096b31ee5e93e",
"b4ce8892d5784011a649eeca7b17fd0d",
"7a7b598979174b338f8fdfc90af61e47",
"cd946e84023445d59eea9893fba02f1a",
"4ec581224b3f47f2acdfaeb740c94526",
"81c25bb635db4bc9a62636bf3a3868eb",
"1462263d23c340e781d596a3f5414e03",
"c8bd20a456cc4783a2fbbdec9a2880ea",
"f0aa35a704a84967a3471e5f698bb04c",
"f2c541b4ecd04729a518414f8b878099",
"402795b6506a45c9831fb0ddb2ce0749",
"0ba00c3340414029938c9c947d4a74b6",
"d17f610741f142b782c74ac4e04e6481",
"84c8a7cbfdb04690941f8c32299744bb",
"f3f0b212907843f5be2131b7b60a955d",
"0a6a241c99bb49ad9baaef052b157964",
"b642c546813240a499a165279c743e7a",
"6acb415dc8b947bb85ab8f8959bc10e1",
"2c311c4506264995b7fbaba02333e747",
"2fde17ab63e5444799ec1dc5501d946e",
"6173f15a3e5a40819bd4f3833544fd78",
"da2f69c01bcd4139980a84ddca297817",
"b9470519dfff45228e54d0181337ee17",
"b5824ca248b74052bb1696b47369838e",
"11f8b1ddd5af4bb3babd492bbaf3fb7d",
"861e2a6f8e3e4a5cacfd0a746180cc90",
"181f0680e5bb4d23a69f9cc747207883",
"0e2531c122b44823a012ba844066e60c",
"a2324cab4e9745ddad69b35e77ed148e",
"42c8cb98eaae4cee86d49c675d3a3604",
"146ce87ac5b445e385285bdbbd70fba9",
"46340cbdaf9c47719e005b42b7236be2",
"cd26f91d46ef42609d265f0b709c3401",
"6f5f003a277244b2824e09b877fefa74",
"57a93be7bb3d4167a89f7d73d5915551",
"db96a8d67e564cd58358916ddfb86948",
"194f5ce951d54193b887f1cc91e3640b",
"d8770c573383466f81d7dfd01f048677",
"1fe6ad92e6834142844bb32e7a0d424d",
"9d905b40bf2746e79abad7257cf8654f",
"bfa957f8482847978542373db16cfa8a",
"1f836daeb9b7495a9ae00bc2c28fa212",
"0973226c646f4bb28f291ed3eb4e5cba",
"fd72a18adfb34bc2b2dab769067dcc1a",
"91fb9675236949e6bb8c99015061647c",
"75f988a47d774cd38fd85977fda7a96c",
"503ff750d98d4f2c8f7af32adc221685",
"8b3645c111b842f4b1d9ff54bb94f26e",
"dabd19b76eec4139a3028aa1776ac496",
"b6410389dafe43bd885446a6d2d255b0",
"23187b20efd44965966d6576c0d057aa",
"6f876daff243410bb26b8feb1cf981e0",
"2374ba361eaa41899b7eabe7da829259",
"522331eac9cb49e4998384f42eeba912",
"6eaf7a1f2e5e4aa590da55adf969ecc6",
"29b0ec79b8134168a5e3ae382088c147",
"96f9368df9dc453495badb09e767d091",
"02b1a546bc334efa93ba8dff5890dd0c",
"b7f8df98ab5d45f39b34ceb1018c01b5",
"c251cf1c33c741049cc9710e89f09bc0",
"7e067f02bdb74846a42909cd9ad1961b",
"68cf981867f7461c999c9e9f85af6499",
"b6eec0df5ca64d5dbf32cb5231127e3b",
"6efef5f4637e4efaa373cb9214d4e911",
"d539e4c93f4941e9af0b742df4c53ae0",
"dafd837bebd44c00ae68641cc0b31d3d",
"47808ac98d064711bccc77037e53dade",
"acfab7adb0ed440b87fa08f81086522d",
"6fa186bd796b4d16bfb67657178e4f65",
"07fe952dec2b4f2687bcddec8701a70e",
"6c9e7145c2d7439db6f63b29c505e63e",
"43861346cadc460992ab7e2c30aa556c",
"6baec27867e943a2bad6e70e48978ea5",
"d2e74c9d62444e98b52078d6a15dd068",
"fcf96c9bb04642089466de2421bd1af2",
"b9d1f631577f4a2a8eb49d3f8535ba25",
"e9c8d09e028741979a764789fd7f256e",
"2bd15c15e3b748cd9ee001bb86c04b6d",
"f5e27446f61747e1a65e477e2975dbcf",
"1e17af39c928463b800d22ad1fe533e5",
"98e148811880459ea7a7641ee3a301b6",
"6669a1844b394ea499dc92e126132bfc",
"987310405403439893813fb8c5da94a1",
"f48e444c929748b2a875c6eecc028bc4",
"1ed01b7ba2a64bebbbf0e641c8d1560e",
"bfadf2c9e99c48e4860c2a67b2955120",
"63f47b7d6bc849d1b869367a6ba8b40f",
"1bfe0a0ab5694e1d857ea981090d4f09",
"7f6628d5f61f424ba6373da2f6defe76",
"01d1c1e078e7403daabd1e7d8d5a4fda",
"3b802a49c0824b968d88350f82ff8422",
"5a6ea875aeff4688ba37481af5b5d9fe",
"6dfceed81b9b4f8198a8186dfdc72d42",
"44ee8edfb7c744358f0cb4567535ce07",
"5eb4c37fcf6a4b618a8f573b837a8c3a",
"5bc32858674f442fa7b49efcccfc3c15",
"3ebb7960558247ec84ef155042d2077a",
"52a25fde354e45dd9635097d22c16def",
"ce3117149c7e4ae6a93fbe959bd3f543",
"900aa99ccfeb4a38a0ad90f8ce656fe6",
"b414ff8df761462e8df8f51788f69508",
"e0c2ef0fd4d544b9b80b00689d21f503",
"52cae29e50974a42a3882f91b2a8f2ba",
"02b454a4d7ce44b0b819fc7085b1d357",
"6321a346b671411791979a1a343fb494",
"22941d4a640b4d8bba1b97f723da0ef0",
"b6f9da40231f4404bb0c90458c94c5c2",
"7db664fc575747af8e10ad6f2783a166",
"62fd4952400246189945f2de7f0192c6",
"f07d424e031d457a98a88b94b0c844b8",
"a9f8af0a5e594c29b553983b6a1c4c7d",
"5829f975dedd45d3b315ab29c1b12253",
"2eb6756f48074993825b5bfac47641e1",
"389fcde31cf24c438a9fcdb155060495",
"a5a86539a58f4fe98ba96f9cd3d46588",
"32c067d64cb7445eb0464f3a10071e7c",
"c014bf31dccb4448b561a702361ed9bc",
"ef823903b36348978833920bb7e03881",
"c6855514c5a649ebadc2db53c267e9e0",
"8602122a90db462cacef9373fca054fe",
"892e02ac20ef4963871d915085139f7d",
"173704b44e2c472499bc4242ef9e0863",
"2f647e8c1be2464fb9c69e8af5e5bf5a",
"5fc3a90707f746c884d924e385b0cc1c"
]
},
"id": "U946A-YZp-Pe",
"outputId": "1fe21ffd-2e39-4470-fb87-6b11ab1d4024"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "17551e2885a240c3be0f5ee31e96aead",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch ...: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2185e0ea86a040eda8f02fedf25cc57c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training...: 0%| | 0/9555 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_359912/842589707.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodel_inputs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;31m# Model forward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_metric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparallel_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mprogress_bar_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mf_pmapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1645\u001b[0m list(out_axes_leaves)))),\n\u001b[1;32m 1646\u001b[0m closure=(tuple(out_axes_leaves), out_axes_treedef))\n\u001b[0;32m-> 1647\u001b[0;31m out = pxla.xla_pmap(\n\u001b[0m\u001b[1;32m 1648\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1649\u001b[0m \u001b[0maxis_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlocal_axis_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mglobal_axis_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1618\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1619\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'in_axes'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1620\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1549\u001b[0m params_tuple, out_axes_transforms)\n\u001b[1;32m 1550\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1551\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1552\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1625\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_process\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 604\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 605\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 606\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 607\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mxla_pmap_impl\u001b[0;34m(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)\u001b[0m\n\u001b[1;32m 635\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"abstract args\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabstractify\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 636\u001b[0m (\"fingerprint\", fingerprint))\n\u001b[0;32m--> 637\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcompiled_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 638\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mlu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/jenv/lib/python3.8/site-packages/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mexecute_replicated\u001b[0;34m(compiled, backend, in_handler, out_handler, *args)\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mexecute_replicated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompiled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1151\u001b[0m \u001b[0minput_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1152\u001b[0;31m \u001b[0mout_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_sharded_on_local_devices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1153\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mneeds_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1154\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbufs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for epoch in tqdm(range(1, num_train_epochs + 1), desc=f\"Epoch ...\", position=0, leave=True):\n",
" rng, input_rng = jax.random.split(rng)\n",
"\n",
" # -- Train --\n",
" train_loader = data_loader(input_rng, tokenized_datasets[\"train\"], total_batch_size, shuffle=True)\n",
" with tqdm(total=len(tokenized_datasets[\"train\"]) // total_batch_size, desc=\"Training...\", leave=False) as progress_bar_train:\n",
" for model_inputs in train_loader:\n",
" # Model forward\n",
" state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)\n",
"\n",
" progress_bar_train.update(1)\n",
"\n",
" progress_bar_train.write(\n",
"            f\"Train... ({epoch}/{num_train_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})\"\n",
" )\n",
"\n",
" # -- Eval --\n",
" eval_loader = data_loader(input_rng, tokenized_datasets[\"validation\"], total_batch_size)\n",
" eval_metrics = []\n",
" \n",
" with tqdm(total=len(tokenized_datasets[\"validation\"]) // total_batch_size, desc=\"Evaluation...\", leave=False) as progress_bar_eval:\n",
" for model_inputs in eval_loader:\n",
" # Model forward\n",
" eval_metric = parallel_eval_step(state.params, model_inputs)\n",
" eval_metrics.append(eval_metric)\n",
"\n",
" progress_bar_eval.update(1)\n",
" \n",
" eval_metrics = get_metrics(eval_metrics)\n",
" eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n",
" progress_bar_eval.write(\n",
"            f\"Eval... ({epoch}/{num_train_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZI4XIhY-7hyh"
},
"source": [
"In this Colab, training reaches a speed of 2.42 training steps per second. Executing [**`run_clm_flax.py`**](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling/run_clm_flax.py) on a TPU v3-8 VM should reach up to 7 training steps per second.\n",
"\n",
"For a more detailed comparison of runtimes, please refer to [this](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation) table."
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [],
"name": "Causal Language Model Training on TPU with 🤗 Transformers & JAX",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"01d1c1e078e7403daabd1e7d8d5a4fda": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"02b1a546bc334efa93ba8dff5890dd0c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"02b454a4d7ce44b0b819fc7085b1d357": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_7db664fc575747af8e10ad6f2783a166",
"placeholder": "",
"style": "IPY_MODEL_b6f9da40231f4404bb0c90458c94c5c2",
"value": " 12/12 [00:05&lt;00:00, 2.35it/s]"
}
},
"07b26680343f42fc805a3b52b398a5f2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"07fe952dec2b4f2687bcddec8701a70e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"0973226c646f4bb28f291ed3eb4e5cba": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"0a6a241c99bb49ad9baaef052b157964": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2fde17ab63e5444799ec1dc5501d946e",
"placeholder": "",
"style": "IPY_MODEL_2c311c4506264995b7fbaba02333e747",
"value": " 137/137 [01:38&lt;00:00, 1.40it/s]"
}
},
"0b898451f79c4b6eb1678a8f366a5cbf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"0ba00c3340414029938c9c947d4a74b6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"0e2531c122b44823a012ba844066e60c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"11f8b1ddd5af4bb3babd492bbaf3fb7d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"125e8d57908f4679824c96d5247f7119": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b08557f6936e4af9a23f659649979c8d",
"placeholder": "",
"style": "IPY_MODEL_e246db4008b14a23b55d2820be3b1ed8",
"value": " 10/10 [19:05&lt;00:00, 114.56s/it]"
}
},
"1462263d23c340e781d596a3f5414e03": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_f2c541b4ecd04729a518414f8b878099",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_f0aa35a704a84967a3471e5f698bb04c",
"value": 12
}
},
"146ce87ac5b445e385285bdbbd70fba9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_6f5f003a277244b2824e09b877fefa74",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_cd26f91d46ef42609d265f0b709c3401",
"value": 137
}
},
"173704b44e2c472499bc4242ef9e0863": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"181f0680e5bb4d23a69f9cc747207883": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"194f5ce951d54193b887f1cc91e3640b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_1fe6ad92e6834142844bb32e7a0d424d",
"IPY_MODEL_9d905b40bf2746e79abad7257cf8654f"
],
"layout": "IPY_MODEL_d8770c573383466f81d7dfd01f048677"
}
},
"1bfe0a0ab5694e1d857ea981090d4f09": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5a6ea875aeff4688ba37481af5b5d9fe",
"placeholder": "",
"style": "IPY_MODEL_3b802a49c0824b968d88350f82ff8422",
"value": " 12/12 [00:05&lt;00:00, 2.35it/s]"
}
},
"1e17af39c928463b800d22ad1fe533e5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f48e444c929748b2a875c6eecc028bc4",
"placeholder": "",
"style": "IPY_MODEL_987310405403439893813fb8c5da94a1",
"value": " 137/137 [01:36&lt;00:00, 1.41it/s]"
}
},
"1ed01b7ba2a64bebbbf0e641c8d1560e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_63f47b7d6bc849d1b869367a6ba8b40f",
"IPY_MODEL_1bfe0a0ab5694e1d857ea981090d4f09"
],
"layout": "IPY_MODEL_bfadf2c9e99c48e4860c2a67b2955120"
}
},
"1f836daeb9b7495a9ae00bc2c28fa212": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1fe6ad92e6834142844bb32e7a0d424d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_1f836daeb9b7495a9ae00bc2c28fa212",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_bfa957f8482847978542373db16cfa8a",
"value": 12
}
},
"22941d4a640b4d8bba1b97f723da0ef0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"23187b20efd44965966d6576c0d057aa": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"2374ba361eaa41899b7eabe7da829259": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_6eaf7a1f2e5e4aa590da55adf969ecc6",
"IPY_MODEL_29b0ec79b8134168a5e3ae382088c147"
],
"layout": "IPY_MODEL_522331eac9cb49e4998384f42eeba912"
}
},
"269cc543a69e4ba3bae7f0bef67f1e25": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_7a5426d6898e4ba2b5e3c48a65096232",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_0b898451f79c4b6eb1678a8f366a5cbf",
"value": 137
}
},
"29b0ec79b8134168a5e3ae382088c147": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c251cf1c33c741049cc9710e89f09bc0",
"placeholder": "",
"style": "IPY_MODEL_b7f8df98ab5d45f39b34ceb1018c01b5",
"value": " 12/12 [00:05&lt;00:00, 2.35it/s]"
}
},
"2bd15c15e3b748cd9ee001bb86c04b6d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2c311c4506264995b7fbaba02333e747": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"2cf7d4c7dc6f413eb8cc9460e5de8d93": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"2eb6756f48074993825b5bfac47641e1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"2f647e8c1be2464fb9c69e8af5e5bf5a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"2fde17ab63e5444799ec1dc5501d946e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"32c067d64cb7445eb0464f3a10071e7c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3549bf51a354408b9a39cb4532957617": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"389fcde31cf24c438a9fcdb155060495": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3b802a49c0824b968d88350f82ff8422": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"3c5df3c942ca4c768926b431064be351": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_269cc543a69e4ba3bae7f0bef67f1e25",
"IPY_MODEL_8cb8398bb3c94284ac2ff42c7b3bc7c5"
],
"layout": "IPY_MODEL_f2e66aebdb4d4812a1f61c1ff990390a"
}
},
"3ebb7960558247ec84ef155042d2077a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"402795b6506a45c9831fb0ddb2ce0749": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"42c8cb98eaae4cee86d49c675d3a3604": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"43861346cadc460992ab7e2c30aa556c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b9d1f631577f4a2a8eb49d3f8535ba25",
"placeholder": "",
"style": "IPY_MODEL_fcf96c9bb04642089466de2421bd1af2",
"value": " 12/12 [00:05&lt;00:00, 2.33it/s]"
}
},
"44ee8edfb7c744358f0cb4567535ce07": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"46340cbdaf9c47719e005b42b7236be2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_db96a8d67e564cd58358916ddfb86948",
"placeholder": "",
"style": "IPY_MODEL_57a93be7bb3d4167a89f7d73d5915551",
"value": " 137/137 [01:37&lt;00:00, 1.40it/s]"
}
},
"47808ac98d064711bccc77037e53dade": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"4bc03fce15a14355b3429c138ea1eadd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_cd946e84023445d59eea9893fba02f1a",
"placeholder": "",
"style": "IPY_MODEL_7a7b598979174b338f8fdfc90af61e47",
"value": " 137/137 [01:37&lt;00:00, 1.41it/s]"
}
},
"4ec581224b3f47f2acdfaeb740c94526": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_1462263d23c340e781d596a3f5414e03",
"IPY_MODEL_c8bd20a456cc4783a2fbbdec9a2880ea"
],
"layout": "IPY_MODEL_81c25bb635db4bc9a62636bf3a3868eb"
}
},
"503ff750d98d4f2c8f7af32adc221685": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_b6410389dafe43bd885446a6d2d255b0",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_dabd19b76eec4139a3028aa1776ac496",
"value": 137
}
},
"51d0dfb67bbd47ff88964bb446a91967": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_7f05c774880548d2b73e69fb847c0a2e",
"placeholder": "",
"style": "IPY_MODEL_3549bf51a354408b9a39cb4532957617",
"value": " 137/137 [03:28&lt;00:00, 1.41it/s]"
}
},
"5215d791d9144c089b6996336c3085cd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"522331eac9cb49e4998384f42eeba912": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"52a25fde354e45dd9635097d22c16def": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"52cae29e50974a42a3882f91b2a8f2ba": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_22941d4a640b4d8bba1b97f723da0ef0",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_6321a346b671411791979a1a343fb494",
"value": 12
}
},
"57a93be7bb3d4167a89f7d73d5915551": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"5829f975dedd45d3b315ab29c1b12253": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_32c067d64cb7445eb0464f3a10071e7c",
"placeholder": "",
"style": "IPY_MODEL_a5a86539a58f4fe98ba96f9cd3d46588",
"value": " 137/137 [01:37&lt;00:00, 1.40it/s]"
}
},
"5a6ea875aeff4688ba37481af5b5d9fe": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5bc32858674f442fa7b49efcccfc3c15": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_900aa99ccfeb4a38a0ad90f8ce656fe6",
"placeholder": "",
"style": "IPY_MODEL_ce3117149c7e4ae6a93fbe959bd3f543",
"value": " 137/137 [01:37&lt;00:00, 1.42it/s]"
}
},
"5e1805a3c7f24c3f9f15e86d15462d44": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5eb4c37fcf6a4b618a8f573b837a8c3a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_52a25fde354e45dd9635097d22c16def",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_3ebb7960558247ec84ef155042d2077a",
"value": 137
}
},
"5fc3a90707f746c884d924e385b0cc1c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6173f15a3e5a40819bd4f3833544fd78": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_b9470519dfff45228e54d0181337ee17",
"IPY_MODEL_b5824ca248b74052bb1696b47369838e"
],
"layout": "IPY_MODEL_da2f69c01bcd4139980a84ddca297817"
}
},
"62fd4952400246189945f2de7f0192c6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_a9f8af0a5e594c29b553983b6a1c4c7d",
"IPY_MODEL_5829f975dedd45d3b315ab29c1b12253"
],
"layout": "IPY_MODEL_f07d424e031d457a98a88b94b0c844b8"
}
},
"6321a346b671411791979a1a343fb494": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"63f47b7d6bc849d1b869367a6ba8b40f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_01d1c1e078e7403daabd1e7d8d5a4fda",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_7f6628d5f61f424ba6373da2f6defe76",
"value": 12
}
},
"6669a1844b394ea499dc92e126132bfc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"68681edff0db43559012b6653571a289": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Epoch ...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_9b35ba10c85048ce8044b87ab9948611",
"max": 10,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_8134ef53fa8044c99b537386e88dbfc9",
"value": 10
}
},
"68cf981867f7461c999c9e9f85af6499": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6acb415dc8b947bb85ab8f8959bc10e1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6baec27867e943a2bad6e70e48978ea5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"6c9e7145c2d7439db6f63b29c505e63e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_d2e74c9d62444e98b52078d6a15dd068",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_6baec27867e943a2bad6e70e48978ea5",
"value": 12
}
},
"6dfceed81b9b4f8198a8186dfdc72d42": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_5eb4c37fcf6a4b618a8f573b837a8c3a",
"IPY_MODEL_5bc32858674f442fa7b49efcccfc3c15"
],
"layout": "IPY_MODEL_44ee8edfb7c744358f0cb4567535ce07"
}
},
"6eaf7a1f2e5e4aa590da55adf969ecc6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_02b1a546bc334efa93ba8dff5890dd0c",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_96f9368df9dc453495badb09e767d091",
"value": 12
}
},
"6efef5f4637e4efaa373cb9214d4e911": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_acfab7adb0ed440b87fa08f81086522d",
"placeholder": "",
"style": "IPY_MODEL_47808ac98d064711bccc77037e53dade",
"value": " 137/137 [01:36&lt;00:00, 1.42it/s]"
}
},
"6f5f003a277244b2824e09b877fefa74": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6f876daff243410bb26b8feb1cf981e0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6fa186bd796b4d16bfb67657178e4f65": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_6c9e7145c2d7439db6f63b29c505e63e",
"IPY_MODEL_43861346cadc460992ab7e2c30aa556c"
],
"layout": "IPY_MODEL_07fe952dec2b4f2687bcddec8701a70e"
}
},
"71fa85dc266149e28685b1f5252185a5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9de6ea64b67c443fbce1fd6aef4a7409",
"IPY_MODEL_9658cc6977bb4fdf8a081aa99c300872"
],
"layout": "IPY_MODEL_9cb943192a154dbc8c7961f128449bbd"
}
},
"75f988a47d774cd38fd85977fda7a96c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"799b382dd6c74d528cf77813f82a7335": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_a4234dc426d3459a93263c8b5fe5a34b",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_940ce122551e42509e8cec0659e47c77",
"value": 137
}
},
"7a5426d6898e4ba2b5e3c48a65096232": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7a7b598979174b338f8fdfc90af61e47": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7d0775fcf6a8497db20326dc94bd2739": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_839daa7d5cdc4669942f1b5ef5bf4903",
"IPY_MODEL_4bc03fce15a14355b3429c138ea1eadd"
],
"layout": "IPY_MODEL_5215d791d9144c089b6996336c3085cd"
}
},
"7db664fc575747af8e10ad6f2783a166": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7e067f02bdb74846a42909cd9ad1961b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_b6eec0df5ca64d5dbf32cb5231127e3b",
"IPY_MODEL_6efef5f4637e4efaa373cb9214d4e911"
],
"layout": "IPY_MODEL_68cf981867f7461c999c9e9f85af6499"
}
},
"7f05c774880548d2b73e69fb847c0a2e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7f6628d5f61f424ba6373da2f6defe76": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"8134ef53fa8044c99b537386e88dbfc9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"81c25bb635db4bc9a62636bf3a3868eb": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"839daa7d5cdc4669942f1b5ef5bf4903": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_b4ce8892d5784011a649eeca7b17fd0d",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_d7cddd2cafe84d17829096b31ee5e93e",
"value": 137
}
},
"84c8a7cbfdb04690941f8c32299744bb": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8602122a90db462cacef9373fca054fe": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5fc3a90707f746c884d924e385b0cc1c",
"placeholder": "",
"style": "IPY_MODEL_2f647e8c1be2464fb9c69e8af5e5bf5a",
"value": " 12/12 [00:05&lt;00:00, 2.35it/s]"
}
},
"861e2a6f8e3e4a5cacfd0a746180cc90": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"892e02ac20ef4963871d915085139f7d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"8b3645c111b842f4b1d9ff54bb94f26e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6f876daff243410bb26b8feb1cf981e0",
"placeholder": "",
"style": "IPY_MODEL_23187b20efd44965966d6576c0d057aa",
"value": " 137/137 [01:38&lt;00:00, 1.43it/s]"
}
},
"8c156c360afb4ddc962b87577e093cc4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_68681edff0db43559012b6653571a289",
"IPY_MODEL_125e8d57908f4679824c96d5247f7119"
],
"layout": "IPY_MODEL_9808f288735b4f5d9eea7377d4603ccc"
}
},
"8cb8398bb3c94284ac2ff42c7b3bc7c5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f7851bd9702949ee8e6278ad69c9d981",
"placeholder": "",
"style": "IPY_MODEL_2cf7d4c7dc6f413eb8cc9460e5de8d93",
"value": " 137/137 [01:37&lt;00:00, 1.41it/s]"
}
},
"900aa99ccfeb4a38a0ad90f8ce656fe6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"91fb9675236949e6bb8c99015061647c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_503ff750d98d4f2c8f7af32adc221685",
"IPY_MODEL_8b3645c111b842f4b1d9ff54bb94f26e"
],
"layout": "IPY_MODEL_75f988a47d774cd38fd85977fda7a96c"
}
},
"940ce122551e42509e8cec0659e47c77": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"9658cc6977bb4fdf8a081aa99c300872": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ce70b524ffab4052912d47b8adee6748",
"placeholder": "",
"style": "IPY_MODEL_a1b9aa80b73748f18e3938df073b3d22",
"value": " 12/12 [00:12&lt;00:00, 1.78it/s]"
}
},
"96f9368df9dc453495badb09e767d091": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"97216855bb7f483882afae68bf78a51d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9808f288735b4f5d9eea7377d4603ccc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"987310405403439893813fb8c5da94a1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"98e148811880459ea7a7641ee3a301b6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"9b35ba10c85048ce8044b87ab9948611": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9cb943192a154dbc8c7961f128449bbd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9d905b40bf2746e79abad7257cf8654f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fd72a18adfb34bc2b2dab769067dcc1a",
"placeholder": "",
"style": "IPY_MODEL_0973226c646f4bb28f291ed3eb4e5cba",
"value": " 12/12 [00:05&lt;00:00, 2.35it/s]"
}
},
"9de6ea64b67c443fbce1fd6aef4a7409": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_97216855bb7f483882afae68bf78a51d",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_07b26680343f42fc805a3b52b398a5f2",
"value": 12
}
},
"a1452333389441eab2c8016cd019d2f5": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a1b9aa80b73748f18e3938df073b3d22": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"a2324cab4e9745ddad69b35e77ed148e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_146ce87ac5b445e385285bdbbd70fba9",
"IPY_MODEL_46340cbdaf9c47719e005b42b7236be2"
],
"layout": "IPY_MODEL_42c8cb98eaae4cee86d49c675d3a3604"
}
},
"a4234dc426d3459a93263c8b5fe5a34b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a5a86539a58f4fe98ba96f9cd3d46588": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"a9f8af0a5e594c29b553983b6a1c4c7d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_389fcde31cf24c438a9fcdb155060495",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_2eb6756f48074993825b5bfac47641e1",
"value": 137
}
},
"ab08bd64073d4d488f9c41e8595bac08": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_fc2e1a72620742499457ab2e31da12a9",
"IPY_MODEL_b3b4d96626044375bad11ba48cc25ee3"
],
"layout": "IPY_MODEL_5e1805a3c7f24c3f9f15e86d15462d44"
}
},
"acfab7adb0ed440b87fa08f81086522d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b08557f6936e4af9a23f659649979c8d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b3b4d96626044375bad11ba48cc25ee3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c85760cb6ccb4020be999357f2501bf9",
"placeholder": "",
"style": "IPY_MODEL_cd4892863d524637b8e567e2beed23bb",
"value": " 12/12 [00:05&lt;00:00, 2.31it/s]"
}
},
"b414ff8df761462e8df8f51788f69508": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_52cae29e50974a42a3882f91b2a8f2ba",
"IPY_MODEL_02b454a4d7ce44b0b819fc7085b1d357"
],
"layout": "IPY_MODEL_e0c2ef0fd4d544b9b80b00689d21f503"
}
},
"b4ce8892d5784011a649eeca7b17fd0d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b5824ca248b74052bb1696b47369838e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_0e2531c122b44823a012ba844066e60c",
"placeholder": "",
"style": "IPY_MODEL_181f0680e5bb4d23a69f9cc747207883",
"value": " 12/12 [00:05&lt;00:00, 2.34it/s]"
}
},
"b6410389dafe43bd885446a6d2d255b0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b642c546813240a499a165279c743e7a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"b6eec0df5ca64d5dbf32cb5231127e3b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_dafd837bebd44c00ae68641cc0b31d3d",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_d539e4c93f4941e9af0b742df4c53ae0",
"value": 137
}
},
"b6f9da40231f4404bb0c90458c94c5c2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"b7f8df98ab5d45f39b34ceb1018c01b5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"b9470519dfff45228e54d0181337ee17": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_861e2a6f8e3e4a5cacfd0a746180cc90",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_11f8b1ddd5af4bb3babd492bbaf3fb7d",
"value": 12
}
},
"b9d1f631577f4a2a8eb49d3f8535ba25": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bfa957f8482847978542373db16cfa8a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"bfadf2c9e99c48e4860c2a67b2955120": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c014bf31dccb4448b561a702361ed9bc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_c6855514c5a649ebadc2db53c267e9e0",
"IPY_MODEL_8602122a90db462cacef9373fca054fe"
],
"layout": "IPY_MODEL_ef823903b36348978833920bb7e03881"
}
},
"c251cf1c33c741049cc9710e89f09bc0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c55e4e5f6d0a42b4ac25fb6d68a8e842": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"c6855514c5a649ebadc2db53c267e9e0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_173704b44e2c472499bc4242ef9e0863",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_892e02ac20ef4963871d915085139f7d",
"value": 12
}
},
"c7cca560cd4340bb986297655f513746": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c85760cb6ccb4020be999357f2501bf9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c8bd20a456cc4783a2fbbdec9a2880ea": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_0ba00c3340414029938c9c947d4a74b6",
"placeholder": "",
"style": "IPY_MODEL_402795b6506a45c9831fb0ddb2ce0749",
"value": " 12/12 [00:05&lt;00:00, 2.35it/s]"
}
},
"cd26f91d46ef42609d265f0b709c3401": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"cd4892863d524637b8e567e2beed23bb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"cd946e84023445d59eea9893fba02f1a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ce3117149c7e4ae6a93fbe959bd3f543": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"ce70b524ffab4052912d47b8adee6748": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d17f610741f142b782c74ac4e04e6481": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_f3f0b212907843f5be2131b7b60a955d",
"IPY_MODEL_0a6a241c99bb49ad9baaef052b157964"
],
"layout": "IPY_MODEL_84c8a7cbfdb04690941f8c32299744bb"
}
},
"d2e74c9d62444e98b52078d6a15dd068": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d539e4c93f4941e9af0b742df4c53ae0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"d7cddd2cafe84d17829096b31ee5e93e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"d8770c573383466f81d7dfd01f048677": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"da2f69c01bcd4139980a84ddca297817": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"dabd19b76eec4139a3028aa1776ac496": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"dac789bcf81e4aac86717438ed6e8bbe": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_799b382dd6c74d528cf77813f82a7335",
"IPY_MODEL_51d0dfb67bbd47ff88964bb446a91967"
],
"layout": "IPY_MODEL_a1452333389441eab2c8016cd019d2f5"
}
},
"dafd837bebd44c00ae68641cc0b31d3d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"db96a8d67e564cd58358916ddfb86948": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e0c2ef0fd4d544b9b80b00689d21f503": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e246db4008b14a23b55d2820be3b1ed8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"e9c8d09e028741979a764789fd7f256e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_f5e27446f61747e1a65e477e2975dbcf",
"IPY_MODEL_1e17af39c928463b800d22ad1fe533e5"
],
"layout": "IPY_MODEL_2bd15c15e3b748cd9ee001bb86c04b6d"
}
},
"ef823903b36348978833920bb7e03881": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f07d424e031d457a98a88b94b0c844b8": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f0aa35a704a84967a3471e5f698bb04c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"f2c541b4ecd04729a518414f8b878099": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f2e66aebdb4d4812a1f61c1ff990390a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f3f0b212907843f5be2131b7b60a955d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_6acb415dc8b947bb85ab8f8959bc10e1",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b642c546813240a499a165279c743e7a",
"value": 137
}
},
"f48e444c929748b2a875c6eecc028bc4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f5e27446f61747e1a65e477e2975dbcf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Training...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_6669a1844b394ea499dc92e126132bfc",
"max": 137,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_98e148811880459ea7a7641ee3a301b6",
"value": 137
}
},
"f7851bd9702949ee8e6278ad69c9d981": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"fc2e1a72620742499457ab2e31da12a9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "Evaluation...: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_c7cca560cd4340bb986297655f513746",
"max": 12,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_c55e4e5f6d0a42b4ac25fb6d68a8e842",
"value": 12
}
},
"fcf96c9bb04642089466de2421bd1af2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"fd72a18adfb34bc2b2dab769067dcc1a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}