Sketchback/sketchback.ipynb

1211 lines
702 KiB
Plaintext
Raw Permalink Normal View History

2017-06-01 18:46:09 +03:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from __future__ import print_function\n",
"import numpy as np\n",
"import pandas as pd \n",
"import cv2 as cv\n",
"import os\n",
"import h5py\n",
"import matplotlib.pyplot as plt\n",
"import scipy.misc\n",
"import scipy.ndimage\n",
"\n",
"from tqdm import tqdm\n",
"from copy import deepcopy\n",
"from sklearn.preprocessing import StandardScaler\n",
"from scipy.misc import imresize\n",
"from keras.models import Sequential, Model\n",
"from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, ZeroPadding2D, Convolution2D, Deconvolution2D, merge\n",
"from keras.layers.core import Activation, Dropout, Flatten, Lambda\n",
"from keras.layers.normalization import BatchNormalization\n",
"from keras.optimizers import SGD, Adam, Nadam\n",
"from keras.utils import np_utils, plot_model\n",
"from keras.callbacks import TensorBoard\n",
"from keras import objectives, layers\n",
"from keras.applications import vgg16\n",
"from keras.applications.vgg16 import preprocess_input\n",
"from keras import backend as K\n",
"\n",
"np.random.seed(1337) # for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# CelebA Faces: 72x88 200K Images\n",
"# ZuBuD Buildings: 120x160 3K Images\n",
"# CUHK Faces: 80x112 88 Images\n",
"\n",
"m = 205\n",
"n = 282\n",
"sketch_dim = (m,n)\n",
"img_dim = (m,n,3)\n",
"num_images = 3000\n",
"num_epochs = 20\n",
"batch_size = 5\n",
"file_names = []\n",
"\n",
"CelebA_SKETCH_PATH = '/home/balkhamissi/Desktop/Project/CelebA_Sketch'\n",
"CelebA_IMAGE_PATH = '/home/balkhamissi/Desktop/Project/img_align_celeba'\n",
"\n",
"BUILDING_SKETCH_PATH = '/home/balkhamissi/Desktop/Project/ZuBuD_Sketch_Aug'\n",
"BUILDING_IMAGE_PATH = '/home/balkhamissi/Desktop/Project/ZuBuD_Aug'\n",
"\n",
"CUHK_SKETCH_PATH = '/home/balkhamissi/Desktop/Project/CUHK_Sketch'\n",
"CUHK_IMAGE_PATH = '/home/balkhamissi/Desktop/Project/CUHK'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/balkhamissi/anaconda3/lib/python2.7/site-packages/ipykernel_launcher.py:2: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=Tensor(\"bl..., inputs=Tensor(\"in...)`\n",
" \n"
]
}
],
"source": [
"base_model = vgg16.VGG16(weights='imagenet', include_top=False)\n",
"vgg = Model(input=base_model.input, output=base_model.get_layer('block2_conv2').output)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_file_names(path):\n",
" return os.listdir(path)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def sub_plot(x,y,z):\n",
" fig = plt.figure()\n",
" a = fig.add_subplot(1,3,1)\n",
" imgplot = plt.imshow(x, cmap='gray')\n",
" a.set_title('Sketch')\n",
" plt.axis(\"off\")\n",
" a = fig.add_subplot(1,3,2)\n",
" imgplot = plt.imshow(z)\n",
" a.set_title('Prediction')\n",
" plt.axis(\"off\")\n",
" a = fig.add_subplot(1,3,3)\n",
" imgplot = plt.imshow(y)\n",
" a.set_title('Ground Truth')\n",
" plt.axis(\"off\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def imshow(x, gray=False):\n",
" plt.imshow(x, cmap='gray' if gray else None)\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_batch(idx, X = True, Y = True, W = True, dataset='zubud'):\n",
" \n",
" global file_names\n",
"\n",
" X_train = np.zeros((batch_size, m, n), dtype='float32')\n",
" Y_train = np.zeros((batch_size, m, n, 3), dtype='float32')\n",
" F_train = None\n",
" \n",
" if dataset == 'zubud':\n",
" x_path = BUILDING_SKETCH_PATH\n",
" y_path = BUILDING_IMAGE_PATH\n",
" elif dataset == 'cuhk':\n",
" x_path = CUHK_SKETCH_PATH\n",
" y_path = CUHK_IMAGE_PATH\n",
" else:\n",
" x_path = CelebA_SKETCH_PATH\n",
" y_path = CelebA_IMAGE_PATH\n",
" \n",
" if len(file_names) == 0:\n",
" file_names = load_file_names(x_path)\n",
" \n",
" if X:\n",
" # Load Sketches\n",
" for i in range(batch_size):\n",
" file = os.path.join(x_path, file_names[i+batch_size*idx])\n",
" img = cv.imread(file,0)\n",
" img = imresize(img, sketch_dim)\n",
" img = img.astype('float32')\n",
" X_train[i] = img / 255.\n",
" \n",
" if Y:\n",
" # Load Ground-truth Images\n",
" for i in range(batch_size):\n",
" file = os.path.join(y_path, file_names[i+batch_size*idx])\n",
" img = cv.imread(file)\n",
" img = imresize(img, img_dim)\n",
" if dataset != 'zubud':\n",
" img = cv.cvtColor(img, cv.COLOR_BGR2RGB)\n",
" img = img.astype('float32')\n",
" Y_train[i] = img / 255.\n",
" \n",
" if W:\n",
" F_train = get_features(Y_train)\n",
" \n",
" X_train = np.reshape(X_train, (batch_size, m, n, 1))\n",
" return X_train, Y_train, F_train"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_features(Y):\n",
" Z = deepcopy(Y)\n",
" Z = preprocess_vgg(Z)\n",
" features = vgg.predict(Z, batch_size = 5, verbose = 0)\n",
" return features"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def preprocess_vgg(x, data_format=None):\n",
" if data_format is None:\n",
" data_format = K.image_data_format()\n",
" assert data_format in {'channels_last', 'channels_first'}\n",
" x = 255. * x\n",
" if data_format == 'channels_first':\n",
" # 'RGB'->'BGR'\n",
" x = x[:, ::-1, :, :]\n",
" # Zero-center by mean pixel\n",
" x[:, 0, :, :] = x[:, 0, :, :] - 103.939\n",
" x[:, 1, :, :] = x[:, 1, :, :] - 116.779\n",
" x[:, 2, :, :] = x[:, 2, :, :] - 123.68\n",
" else:\n",
" # 'RGB'->'BGR'\n",
" x = x[:, :, :, ::-1]\n",
" # Zero-center by mean pixel\n",
" x[:, :, :, 0] = x[:, :, :, 0] - 103.939\n",
" x[:, :, :, 1] = x[:, :, :, 1] - 116.779\n",
" x[:, :, :, 2] = x[:, :, :, 2] - 123.68\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def feature_loss(y_true, y_pred):\n",
" return K.sqrt(K.mean(K.square(y_true - y_pred)))\n",
"\n",
"def pixel_loss(y_true, y_pred):\n",
" return K.sqrt(K.mean(K.square(y_true - y_pred))) + 0.00001*total_variation_loss(y_pred)\n",
"\n",
"def adv_loss(y_true, y_pred):\n",
" return K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1)\n",
"\n",
"def total_variation_loss(y_pred):\n",
" if K.image_data_format() == 'channels_first':\n",
" a = K.square(y_pred[:, :, :m - 1, :n - 1] - y_pred[:, :, 1:, :n - 1])\n",
" b = K.square(y_pred[:, :, :m - 1, :n - 1] - y_pred[:, :, :m - 1, 1:])\n",
" else:\n",
" a = K.square(y_pred[:, :m - 1, :n - 1, :] - y_pred[:, 1:, :n - 1, :])\n",
" b = K.square(y_pred[:, :m - 1, :n - 1, :] - y_pred[:, :m - 1, 1:, :])\n",
" return K.sum(K.pow(a + b, 1.25))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def preprocess_VGG(x, dim_ordering='default'):\n",
" if dim_ordering == 'default':\n",
" dim_ordering = K.image_dim_ordering()\n",
" assert dim_ordering in {'tf', 'th'}\n",
" # x has pixels intensities between 0 and 1\n",
" x = 255. * x\n",
" norm_vec = K.variable([103.939, 116.779, 123.68])\n",
" if dim_ordering == 'th':\n",
" norm_vec = K.reshape(norm_vec, (1,3,1,1))\n",
" x = x - norm_vec\n",
" # 'RGB'->'BGR'\n",
" x = x[:, ::-1, :, :]\n",
" else:\n",
" norm_vec = K.reshape(norm_vec, (1,1,1,3))\n",
" x = x - norm_vec\n",
" # 'RGB'->'BGR'\n",
" x = x[:, :, :, ::-1]\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def generator_model(input_img):\n",
"\n",
" # Encoder\n",
" x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)\n",
" x = Conv2D(32, (2, 2), activation='relu', padding='same')(x)\n",
" x = MaxPooling2D((2, 2), padding='same')(x)\n",
"\n",
" x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)\n",
" x = Conv2D(64, (2, 2), activation='relu', padding='same')(x)\n",
" x = MaxPooling2D((2, 2), padding='same')(x)\n",
"\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)\n",
" x = MaxPooling2D((2, 2), padding='same')(x)\n",
"\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)\n",
" res = Conv2D(256, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
" res = Conv2D(256, (3, 3), activation='relu', padding='same')(x)\n",
" encoded = layers.add([x, res])\n",
"\n",
" # Decoder\n",
" res = Conv2D(256, (3, 3), activation='relu', padding='same', name='block5_conv1')(encoded)\n",
" x = layers.add([encoded, res])\n",
" res = Conv2D(256, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
" res = Conv2D(256, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
"\n",
" x = Conv2D(128, (2, 2), activation='relu', padding='same', name='block6_conv1')(x)\n",
" x = UpSampling2D((2, 2))(x)\n",
"\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block7_conv1')(x)\n",
" res = Conv2D(128, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
" res = Conv2D(128, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
"\n",
" x = Conv2D(64, (2, 2), activation='relu', padding='same', name='block8_conv1')(x)\n",
" x = UpSampling2D((2, 2))(x)\n",
"\n",
" x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block9_conv1')(x)\n",
" res = Conv2D(64, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
" res = Conv2D(64, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
"\n",
" x = Conv2D(32, (2, 2), activation='relu', padding='same', name='block10_conv1')(x)\n",
" x = UpSampling2D((2, 2))(x)\n",
"\n",
" x = Conv2D(32, (3, 3), activation='relu', padding='same', name='block11_conv1')(x)\n",
" res = Conv2D(32, (3, 3), activation='relu', padding='same')(x)\n",
" x = layers.add([x, res])\n",
" decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)\n",
" \n",
" return decoded"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def generator_model_2(input_img):\n",
" x = Convolution2D(32, (9, 9), padding=\"same\", strides=(1,1))(input_img)\n",
" x = BatchNormalization(axis=1)(x)\n",
" x = Activation(\"relu\")(x)\n",
" \n",
" x = Convolution2D(64, (3, 3), padding=\"same\", strides=(2,2))(x)\n",
" x = BatchNormalization(axis=1)(x)\n",
" x = Activation(\"relu\")(x)\n",
" \n",
" x = Convolution2D(128, (3, 3), padding=\"same\", strides=(2,2))(x)\n",
" x = BatchNormalization(axis=1)(x)\n",
" x = Activation(\"relu\")(x)\n",
" \n",
" # then 5 res blocks\n",
" \n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(x)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(r)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" # Merge residual and identity\n",
" x = merge([x, r], mode='sum', concat_axis=1)\n",
" \n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(x)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(r)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" # Merge residual and identity\n",
" x = merge([x, r], mode='sum', concat_axis=1)\n",
" \n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(x)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(r)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" # Merge residual and identity\n",
" x = merge([x, r], mode='sum', concat_axis=1)\n",
" \n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(x)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(r)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" # Merge residual and identity\n",
" x = merge([x, r], mode='sum', concat_axis=1)\n",
" \n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(x)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" r = Convolution2D(128, (3, 3), padding=\"same\")(r)\n",
" r = BatchNormalization(axis=1)(r)\n",
" r = Activation(\"relu\")(r)\n",
" # Merge residual and identity\n",
" x = merge([x, r], mode='sum', concat_axis=1)\n",
" \n",
" # the 2 deconv blocks\n",
" x = Deconvolution2D(64, (3, 3), output_shape=(batch_size, m/2, n/2, 64), padding='same', strides=(2,2))(x)\n",
" x = BatchNormalization(axis=1)(x)\n",
" x = Activation(\"relu\")(x)\n",
" \n",
" x = Deconvolution2D(32, (3, 3), output_shape=(batch_size, m/2, n/2, 32), padding='same', strides=(2,2))(x)\n",
" x = BatchNormalization(axis=1)(x)\n",
" x = Activation(\"relu\")(x)\n",
"\n",
" # final conv block\n",
" x = Convolution2D(3, (9, 9), padding=\"same\", strides=(1,1))(x)\n",
" x = BatchNormalization(axis=1)(x)\n",
" x = Activation(\"sigmoid\")(x)\n",
" \n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def discriminator_model(img_input):\n",
" \n",
" x = Conv2D(64, (3, 3), activation='relu', padding='same', name='d_block1_conv1')(img_input)\n",
" x = Conv2D(64, (3, 3), activation='relu', padding='same', name='d_block1_conv2')(x)\n",
" x = MaxPooling2D((2, 2), strides=(2, 2), name='d_block1_pool')(x)\n",
"\n",
" # Block 2\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same', name='d_block2_conv1')(x)\n",
" x = Conv2D(128, (3, 3), activation='relu', padding='same', name='d_block2_conv2')(x)\n",
" x = MaxPooling2D((2, 2), strides=(2, 2), name='d_block2_pool')(x)\n",
"\n",
" # Block 3\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='d_block3_conv1')(x)\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='d_block3_conv2')(x)\n",
" x = Conv2D(256, (3, 3), activation='relu', padding='same', name='d_block3_conv3')(x)\n",
" x = MaxPooling2D((2, 2), strides=(2, 2), name='d_block3_pool')(x)\n",
"\n",
" # Block 4\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='d_block4_conv1')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='d_block4_conv2')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='d_block4_conv3')(x)\n",
" x = MaxPooling2D((2, 2), strides=(2, 2), name='d_block4_pool')(x)\n",
"\n",
" # Block 5\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='d_block5_conv1')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='d_block5_conv2')(x)\n",
" x = Conv2D(512, (3, 3), activation='relu', padding='same', name='d_block5_conv3')(x)\n",
" x = MaxPooling2D((2, 2), strides=(2, 2), name='d_block5_pool')(x)\n",
" \n",
" x = Flatten(name='flatten')(x)\n",
" x = Dense(512, activation='relu', name='d_fc1')(x)\n",
" x = Dropout(0.5)(x)\n",
" x = Dense(1, activation='relu', name='d_fc2')(x)\n",
" model = Model(input=img_input, output=x)\n",
" \n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def feat_model(img_input):\n",
" # extract vgg feature\n",
" vgg_16 = vgg16.VGG16(include_top=False, weights='imagenet', input_tensor=None)\n",
" # freeze VGG_16 when training\n",
" for layer in vgg_16.layers:\n",
" layer.trainable = False\n",
" \n",
" vgg_first2 = Model(input=vgg_16.input, output=vgg_16.get_layer('block2_conv2').output)\n",
" Norm_layer = Lambda(preprocess_VGG)\n",
" x_VGG = Norm_layer(img_input)\n",
" feat = vgg_first2(x_VGG)\n",
" return feat"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def full_model(summary = True):\n",
" input_img = Input(shape=(m, n, 1))\n",
" generator = generator_model(input_img)\n",
" feat = feat_model(generator)\n",
" model = Model(input=input_img, output=[generator, feat], name='architect')\n",
" model.summary()\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_gen_model():\n",
" gen_model = full_model()\n",
" model = Model(input=gen_model.input, output=gen_model.get_layer('block2_conv1').output)\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def train_full_model():\n",
" generator = generator_model(True)\n",
" discriminator = discriminator_model()\n",
"\n",
" full_model = Sequential()\n",
" full_model.add(generator)\n",
" full_model.add(discriminator)\n",
"\n",
" def loss(y_true, y_pred):\n",
" return 1 - discriminator.predict(y_pred)\n",
"\n",
" generator.compile(loss=loss, optimizer='adam')\n",
" full_model.compile(loss='binary_crossentropy', optimizer='adam')\n",
" discriminator.compile(loss='binary_crossentropy', optimizer='adam')\n",
" #128\n",
" for epoch in num_epochs:\n",
" num_batches = num_images // batch_size\n",
"\n",
" for batch in num_batches:\n",
" X,Y = get_batch(batch)\n",
"\n",
" Y_pred = generator.predict(X) \n",
" discriminator_Y = [0] * batch_size + [1] * batch_size\n",
" discriminator_X = np.concatenate(Y_pred, Y)\n",
" discriminator.trainable=True\n",
" discr_loss = discriminator.fit_on_batch(discriminator_X, discriminator_Y)\n",
"\n",
" discriminator.trainable=False\n",
"\n",
" generator_loss = generator.fit_on_batch(X, Y)\n",
"\n",
" generator.save_weights(generator,True)\n",
" discriminator.save_weights(discriminator, True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def compute_vgg():\n",
" base_model = vgg16.VGG16(weights='imagenet', include_top=False)\n",
" model = Model(input=base_model.input, output=base_model.get_layer('block2_conv2').output)\n",
" num_batches = num_images // batch_size\n",
" for batch in range(num_batches):\n",
" _, Y = get_batch(batch, X = False);\n",
" Y = preprocess_vgg(Y)\n",
" features = model.predict(Y, verbose = 1)\n",
" f = h5py.File('features/feat_%d' % batch, \"w\")\n",
" dset = f.create_dataset(\"features\", data=features)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# model = get_full_model()\n",
"\n",
"# adam = Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=1e-8)\n",
"# model.compile(loss=[pixel_loss, feature_loss, adv_loss], loss_weights=[1, 1, 1], optimizer=adam)\n",
"# file_names = load_file_names(IMAGE_PATH)\n",
"# model.summary()\n",
"# # Threshold / Discriminator Starts at Layer #41\n",
"\n",
"\n",
"# for i, layer in enumerate(model.layers):\n",
"# print i, layer.name\n",
" \n",
"# sub_batch_size = 5\n",
"# for epoch in range(num_epochs):\n",
"# num_batches = num_images // batch_size\n",
"\n",
"# for batch in range(num_batches):\n",
"# X,Y,W = get_batch(batch)\n",
"# D = batch_size*[0]\n",
"# print \"training on batch %d\" % batch\n",
"# for layer in model.layers[41:]:\n",
"# layer.trainable = False\n",
"# history = model.fit(X, [Y,W,D], verbose = True, shuffle=\"batch\", epochs = 1, batch_size=sub_batch_size)\n",
" \n",
" \n",
"\n",
"# model.save_weights(\"weights_2_%d_%d\" % (epoch, batch))"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"____________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"====================================================================================================\n",
"input_10 (InputLayer) (None, 205, 282, 1) 0 \n",
"____________________________________________________________________________________________________\n",
"conv2d_57 (Conv2D) (None, 205, 282, 32) 320 input_10[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_58 (Conv2D) (None, 205, 282, 32) 4128 conv2d_57[0][0] \n",
"____________________________________________________________________________________________________\n",
"max_pooling2d_13 (MaxPooling2D) (None, 103, 141, 32) 0 conv2d_58[0][0] \n",
"____________________________________________________________________________________________________\n",
"block2_conv1 (Conv2D) (None, 103, 141, 64) 18496 max_pooling2d_13[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_59 (Conv2D) (None, 103, 141, 64) 16448 block2_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"max_pooling2d_14 (MaxPooling2D) (None, 52, 71, 64) 0 conv2d_59[0][0] \n",
"____________________________________________________________________________________________________\n",
"block3_conv1 (Conv2D) (None, 52, 71, 128) 73856 max_pooling2d_14[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_60 (Conv2D) (None, 52, 71, 128) 147584 block3_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"max_pooling2d_15 (MaxPooling2D) (None, 26, 36, 128) 0 conv2d_60[0][0] \n",
"____________________________________________________________________________________________________\n",
"block4_conv1 (Conv2D) (None, 26, 36, 256) 295168 max_pooling2d_15[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_61 (Conv2D) (None, 26, 36, 256) 590080 block4_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_41 (Add) (None, 26, 36, 256) 0 block4_conv1[0][0] \n",
" conv2d_61[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_62 (Conv2D) (None, 26, 36, 256) 590080 add_41[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_42 (Add) (None, 26, 36, 256) 0 add_41[0][0] \n",
" conv2d_62[0][0] \n",
"____________________________________________________________________________________________________\n",
"block5_conv1 (Conv2D) (None, 26, 36, 256) 590080 add_42[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_43 (Add) (None, 26, 36, 256) 0 add_42[0][0] \n",
" block5_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_63 (Conv2D) (None, 26, 36, 256) 590080 add_43[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_44 (Add) (None, 26, 36, 256) 0 add_43[0][0] \n",
" conv2d_63[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_64 (Conv2D) (None, 26, 36, 256) 590080 add_44[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_45 (Add) (None, 26, 36, 256) 0 add_44[0][0] \n",
" conv2d_64[0][0] \n",
"____________________________________________________________________________________________________\n",
"block6_conv1 (Conv2D) (None, 26, 36, 128) 131200 add_45[0][0] \n",
"____________________________________________________________________________________________________\n",
"up_sampling2d_13 (UpSampling2D) (None, 52, 72, 128) 0 block6_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"block7_conv1 (Conv2D) (None, 52, 72, 128) 147584 up_sampling2d_13[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_65 (Conv2D) (None, 52, 72, 128) 147584 block7_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_46 (Add) (None, 52, 72, 128) 0 block7_conv1[0][0] \n",
" conv2d_65[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_66 (Conv2D) (None, 52, 72, 128) 147584 add_46[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_47 (Add) (None, 52, 72, 128) 0 add_46[0][0] \n",
" conv2d_66[0][0] \n",
"____________________________________________________________________________________________________\n",
"block8_conv1 (Conv2D) (None, 52, 72, 64) 32832 add_47[0][0] \n",
"____________________________________________________________________________________________________\n",
"up_sampling2d_14 (UpSampling2D) (None, 104, 144, 64) 0 block8_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"block9_conv1 (Conv2D) (None, 104, 144, 64) 36928 up_sampling2d_14[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_67 (Conv2D) (None, 104, 144, 64) 36928 block9_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_48 (Add) (None, 104, 144, 64) 0 block9_conv1[0][0] \n",
" conv2d_67[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_68 (Conv2D) (None, 104, 144, 64) 36928 add_48[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_49 (Add) (None, 104, 144, 64) 0 add_48[0][0] \n",
" conv2d_68[0][0] \n",
"____________________________________________________________________________________________________\n",
"block10_conv1 (Conv2D) (None, 104, 144, 32) 8224 add_49[0][0] \n",
"____________________________________________________________________________________________________\n",
"up_sampling2d_15 (UpSampling2D) (None, 208, 288, 32) 0 block10_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"block11_conv1 (Conv2D) (None, 208, 288, 32) 9248 up_sampling2d_15[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_69 (Conv2D) (None, 208, 288, 32) 9248 block11_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_50 (Add) (None, 208, 288, 32) 0 block11_conv1[0][0] \n",
" conv2d_69[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_70 (Conv2D) (None, 208, 288, 3) 867 add_50[0][0] \n",
"____________________________________________________________________________________________________\n",
"lambda_5 (Lambda) (None, 208, 288, 3) 0 conv2d_70[0][0] \n",
"____________________________________________________________________________________________________\n",
"model_6 (Model) multiple 260160 lambda_5[0][0] \n",
"====================================================================================================\n",
"Total params: 4,511,715\n",
"Trainable params: 4,251,555\n",
"Non-trainable params: 260,160\n",
"____________________________________________________________________________________________________\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/balkhamissi/anaconda3/lib/python2.7/site-packages/ipykernel_launcher.py:8: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=Tensor(\"bl..., inputs=Tensor(\"in...)`\n",
" \n",
"/home/balkhamissi/anaconda3/lib/python2.7/site-packages/ipykernel_launcher.py:5: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=[<tf.Tenso..., name=\"architect\", inputs=Tensor(\"in...)`\n",
" \"\"\"\n"
]
}
],
"source": [
"model = full_model()\n",
"optim = Adam(lr=1e-4,beta_1=0.9, beta_2=0.999, epsilon=1e-8)\n",
"#optim = SGD(lr=1e-4, decay=1e-3, momentum=0.7, nesterov=True)\n",
"#optim = Nadam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)\n",
"#model.compile(loss=[pixel_loss, feature_loss], loss_weights=[1, 0.01], optimizer=optim)\n",
"model.load_weights('weights_5_17')\n",
"\n",
"# Note to Self\n",
"# Pixel Loss wasn't decreasing, always ~151 while the feature loss was decreasing but very slowly\n",
"# Last loss achieved: 849 = 151 + 698\n",
"# While the loss at the beginning was: 899 = 151 + 748\n",
"\n",
"# sub_batch_size = 5\n",
"# for epoch in range(num_epochs):\n",
"# num_batches = num_images // batch_size\n",
"\n",
"# for batch in range(num_batches):\n",
"# X,Y,W = get_batch(batch, dataset='zubud')\n",
"# #loss = model.fit(X, X, verbose = True, shuffle=\"batch\", epochs = 1, batch_size=sub_batch_size)\n",
"# loss = model.train_on_batch(X, [Y, W])\n",
"# print(\"Loss in Epoch # \",epoch,\"| Batch #\", batch, \":\", loss)\n",
"\n",
"# model.save_weights(\"weights_6_%d\" % epoch)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/balkhamissi/anaconda3/lib/python2.7/site-packages/ipykernel_launcher.py:8: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=Tensor(\"bl..., inputs=Tensor(\"in...)`\n",
" \n",
"/home/balkhamissi/anaconda3/lib/python2.7/site-packages/ipykernel_launcher.py:5: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=[<tf.Tenso..., name=\"architect\", inputs=Tensor(\"in...)`\n",
" \"\"\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"____________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"====================================================================================================\n",
"input_30 (InputLayer) (None, 120, 160, 1) 0 \n",
"____________________________________________________________________________________________________\n",
"conv2d_197 (Conv2D) (None, 120, 160, 32) 320 input_30[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_198 (Conv2D) (None, 120, 160, 32) 4128 conv2d_197[0][0] \n",
"____________________________________________________________________________________________________\n",
"max_pooling2d_43 (MaxPooling2D) (None, 60, 80, 32) 0 conv2d_198[0][0] \n",
"____________________________________________________________________________________________________\n",
"block2_conv1 (Conv2D) (None, 60, 80, 64) 18496 max_pooling2d_43[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_199 (Conv2D) (None, 60, 80, 64) 16448 block2_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"max_pooling2d_44 (MaxPooling2D) (None, 30, 40, 64) 0 conv2d_199[0][0] \n",
"____________________________________________________________________________________________________\n",
"block3_conv1 (Conv2D) (None, 30, 40, 128) 73856 max_pooling2d_44[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_200 (Conv2D) (None, 30, 40, 128) 147584 block3_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"max_pooling2d_45 (MaxPooling2D) (None, 15, 20, 128) 0 conv2d_200[0][0] \n",
"____________________________________________________________________________________________________\n",
"block4_conv1 (Conv2D) (None, 15, 20, 256) 295168 max_pooling2d_45[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_201 (Conv2D) (None, 15, 20, 256) 590080 block4_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_141 (Add) (None, 15, 20, 256) 0 block4_conv1[0][0] \n",
" conv2d_201[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_202 (Conv2D) (None, 15, 20, 256) 590080 add_141[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_142 (Add) (None, 15, 20, 256) 0 add_141[0][0] \n",
" conv2d_202[0][0] \n",
"____________________________________________________________________________________________________\n",
"block5_conv1 (Conv2D) (None, 15, 20, 256) 590080 add_142[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_143 (Add) (None, 15, 20, 256) 0 add_142[0][0] \n",
" block5_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_203 (Conv2D) (None, 15, 20, 256) 590080 add_143[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_144 (Add) (None, 15, 20, 256) 0 add_143[0][0] \n",
" conv2d_203[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_204 (Conv2D) (None, 15, 20, 256) 590080 add_144[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_145 (Add) (None, 15, 20, 256) 0 add_144[0][0] \n",
" conv2d_204[0][0] \n",
"____________________________________________________________________________________________________\n",
"block6_conv1 (Conv2D) (None, 15, 20, 128) 131200 add_145[0][0] \n",
"____________________________________________________________________________________________________\n",
"up_sampling2d_43 (UpSampling2D) (None, 30, 40, 128) 0 block6_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"block7_conv1 (Conv2D) (None, 30, 40, 128) 147584 up_sampling2d_43[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_205 (Conv2D) (None, 30, 40, 128) 147584 block7_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_146 (Add) (None, 30, 40, 128) 0 block7_conv1[0][0] \n",
" conv2d_205[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_206 (Conv2D) (None, 30, 40, 128) 147584 add_146[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_147 (Add) (None, 30, 40, 128) 0 add_146[0][0] \n",
" conv2d_206[0][0] \n",
"____________________________________________________________________________________________________\n",
"block8_conv1 (Conv2D) (None, 30, 40, 64) 32832 add_147[0][0] \n",
"____________________________________________________________________________________________________\n",
"up_sampling2d_44 (UpSampling2D) (None, 60, 80, 64) 0 block8_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"block9_conv1 (Conv2D) (None, 60, 80, 64) 36928 up_sampling2d_44[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_207 (Conv2D) (None, 60, 80, 64) 36928 block9_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_148 (Add) (None, 60, 80, 64) 0 block9_conv1[0][0] \n",
" conv2d_207[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_208 (Conv2D) (None, 60, 80, 64) 36928 add_148[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_149 (Add) (None, 60, 80, 64) 0 add_148[0][0] \n",
" conv2d_208[0][0] \n",
"____________________________________________________________________________________________________\n",
"block10_conv1 (Conv2D) (None, 60, 80, 32) 8224 add_149[0][0] \n",
"____________________________________________________________________________________________________\n",
"up_sampling2d_45 (UpSampling2D) (None, 120, 160, 32) 0 block10_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"block11_conv1 (Conv2D) (None, 120, 160, 32) 9248 up_sampling2d_45[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_209 (Conv2D) (None, 120, 160, 32) 9248 block11_conv1[0][0] \n",
"____________________________________________________________________________________________________\n",
"add_150 (Add) (None, 120, 160, 32) 0 block11_conv1[0][0] \n",
" conv2d_209[0][0] \n",
"____________________________________________________________________________________________________\n",
"conv2d_210 (Conv2D) (None, 120, 160, 3) 867 add_150[0][0] \n",
"____________________________________________________________________________________________________\n",
"lambda_15 (Lambda) (None, 120, 160, 3) 0 conv2d_210[0][0] \n",
"____________________________________________________________________________________________________\n",
"model_21 (Model) multiple 260160 lambda_15[0][0] \n",
"====================================================================================================\n",
"Total params: 4,511,715\n",
"Trainable params: 4,251,555\n",
"Non-trainable params: 260,160\n",
"____________________________________________________________________________________________________\n",
"3015\n"
]
}
],
"source": [
"# Rebuild the model at ZuBuD resolution (120x160), restore trained weights and load the building file list.\n",
"m, n = 120, 160\n",
"sketch_dim = (m, n)\n",
"img_dim = sketch_dim + (3,)\n",
"# NOTE(review): 200000 matches the CelebA size noted in the config cell, but this cell loads\n",
"# ZuBuD (~3K files) -- confirm num_images before re-enabling the training loop below.\n",
"num_images = 200000\n",
"\n",
"model = full_model()\n",
"optim = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-8)\n",
"# Alternative optimizers (kept for reference):\n",
"#   SGD(lr=1e-4, decay=1e-3, momentum=0.7, nesterov=True)\n",
"#   Nadam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)\n",
"model.compile(optimizer=optim,\n",
"              loss=[pixel_loss, feature_loss],  # one loss per model output\n",
"              loss_weights=[1, 0.01])\n",
"\n",
"file_names = load_file_names(BUILDING_IMAGE_PATH)\n",
"print(len(file_names))\n",
"model.load_weights('weights_4_24')\n",
"sub_batch_size = 5\n",
"\n",
"# Training loop (disabled: the weights above are loaded instead of retrained).\n",
"# NOTE(review): get_batch is called below with Building=True, while the next cell uses\n",
"# get_batch(..., Y=True, W=False, dataset='zubud') -- the call here looks outdated.\n",
"# for epoch in range(num_epochs):\n",
"#     num_batches = num_images // batch_size\n",
"#     for batch in range(num_batches):\n",
"#         X, Y, W = get_batch(batch, Building=True)\n",
"#         loss = model.train_on_batch(X, [Y, W])\n",
"#         print(\"Loss in Epoch # \", epoch, \"| Batch #\", batch, \":\", loss)\n",
"#     model.save_weights(\"weights_4_%d\" % epoch)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Fetch ZuBuD batch 600 and run the model on its first five inputs; the model has two outputs.\n",
"X, T, _ = get_batch(600, dataset='zubud', Y=True, W=False)\n",
"Y, W = model.predict(X[0:5])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAB5CAYAAAA3Q+qKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXmUHdd95/f53Vre3u/1hkY3CDRAEBRIUSRIkZRMSZQ8\n8oyj0diKHXlTYkVW7ETH8SQ54xkrcpQjjRNbZ2biM54TnfE41jiSxrHHi2LPyLIiU5Qpi9qoEQVu\nAEksxNLoRqP3flut95c/qur1A0yCIAVIENVfnibeq1d169a9Vd/7+31/v3tLVJVtbGMb29jGKxfm\nu12BbWxjG9vYxrXFNtFvYxvb2MYrHNtEv41tbGMbr3BsE/02trGNbbzCsU3029jGNrbxCsc20W9j\nG9vYxisc20T/AhCR94jIw9+B8zwkIj9/rc+zjctDRPaKiIqIm3//rIj81y+jnD0i0hER5+rXchvX\nAiJySkR+6Lt4/jkRecu1PMf3PdGLyBtF5CsisiEiqyLyZRG559ssU0XkpqtVx21sIX8o+zmZLorI\nx0WkfrXPo6pvU9VPXGF9BiShqmdUta6q6dWu0/cqROSnReTrItIVkQv5518UEflu1+1yyAf7Tv4X\ni0g09P3fvMwyf19EPnyVq/qi+L4mehEZAf4C+D+BMWAX8E+B8LtZr228KH5EVevAXcDdwAeHf5QM\n39f39vUCEfll4F8B/wLYCUwB7wPeAPgvcMx14Q3lg309v9f+H+CfF99V9X2X7l94g9cjvt8fhpsB\nVPUPVTVV1b6q/pWqPn7pjiLyL0TkYRFp5t/fKyJHRWRNRD4nIrP59r/JD3ksH/l/Kt/+DhE5LCKb\nInJCRP6zoeJnc0+iLSJ/JSIT1/ayXxlQ1XPAZ4Hbcgns10Xky0APuFFEmiLyb0VkQUTOicj/XpCI\niDgi8n+IyLKInATePlz2pZKaiPxC3t9tETkiIneJyL8D9gCfzvv6V55HApoRkf+Ye4vHReQXhsr8\nsIj8sYh8Mi/3KRG5+5o33HcI+bPya8AvquqfqmpbM3xLVf9LVQ3z/T4uIr8tIn8pIl3gB/O++6SI\nLInIaRH5YDF45+32+0PnubTNHxKR/+2FnikR+dm8zBUR+V++jev7odyj+1UROQ/8roj8vIg8NLSP\nm9dtr4j8IvBTwK/m98ufDRV3l4g8IZmy8IciUnq59Xo+fL8T/bNAKiKfEJG3icjopTuIiBGR3wVu\nB/6eqm6IyDuAXwV+HJgEvgT8IYCq3p8fekc+8v+RiNwLfBL4J0ALuB84NXSadwE/B+wgs3L+8dW/\n1FceRGQ38PeBb+Wbfhb4b4EGcBr4OJAANwF3An8PKMj7F4B/kG+/G3jnZc7zE8CHgXcDI8CPAiuq\n+rPAGXIPQ1X/+fMc/u+BOWAmP8dviMjfGfr9R/N9WsB/BD56pdf/PYAfAErAf7iCfd8F/DpZ3z1M\n5mU3gRuBN5O1/c+9hHM/7zMlIrcCv012r8wA48ANL6HcS3EDUCcb8H/xcjuq6r8G/gj4jfx++bGh\nn38S+Ltk1/vavH5XDd/XRK+qm8AbAQV+F1jKra+pfBePjMDHyB7mXr79fcBHVPWoqibAbwCHCqv+\nefDfAL+nqg+oqlXVc6r69NDv/7eqPquqfeCPgUNX9UJfefhzEVknI4QvkrU/wMdV9am8T8bIBoH/\nSVW7qnoB+JfAT+f7/iTwW6p6VlVXgY9c5nw/T+a2fyO3SI+r6ukXq2Q+EL0BeL+qBqp6GPgYGWkV\neFhV/zLX9P8dcMcVtsH3AiaA5bw/AJAsHrYuWZzl/qF9/4OqfllVLRCT9dMHci/gFPCbvDTye6Fn\n6p3AX6jq3+Qexf8K2Jd9hZkh8WFVjfJzvVz8lqqeV9UVMjn5qnLAdaspfaegqkeB9wCIyEHg94Hf\nAj5HZgneAdyrqtHQYbPAvxKR3xzaJmQa//MR
wG7gLy9TjfNDn3tkFsI2Xhj/uap+fniDZHG9s0Ob\nZskG6gXZivmZoX1mLtn/csS9GzjxMuo5A6yqavuS8wzLM5f2fVlE3GFy/B7GCjAxfD2qeh9kmSZc\nbGgO98UEWd8N98lpsufrSvFCz9RF/a6qXRFZeQnlXorFS7jh5eLS+o5dhTIH+L626C9FbmV/HLgt\n33SUzP37rIi8amjXs8B/p6qtob+Kqn7lBYo+C+y/VvXexgDDS7GeJQuqTwz10Yiqvjr/fYGMwAvs\nuUy5l+u/yy3/Og+MiUjjkvOcu8wxryR8lawP3nEF+w634zKZVT/sIQ+3WxeoDv228yXU6aJ+F5Eq\nmXzzcnFp/79Y3b4rywV/XxO9iBwUkV8WkRvy77uBnwG+Vuyjqn9Ipsd/XkSKh/3fAB8QkVfnxzVz\nHbfAIpnWVuDfAj8nIm/NNf9dufewjWsEVV0A/gr4TREZydt9v4i8Od/lj4H/QURuyGMz//NlivsY\n8I9F5LWS4aYhme7Svh6uw1ngK8BHRKQsIreTyXi//3z7v9KgqutkWWz/WkTeKSKNvB8OAbXLHJeS\n9c+v58fMAv+IrXY7DNwv2ZyFJvCBl1CtPwX+gWRp1T5ZsPhq8uBjwO0i8hoRqQAfuuT3F7xfriW+\nr4keaAOvA76eR/u/BjwJ/PLwTnk+9a8BXxCRvar6Z8A/A/69iGzmx7xt6JAPA5/ItcifVNVHyDyD\nfwlskOnKL6Tnb+Pq4d1kgbgjwBrZQz6d//a7ZPLcY8CjwP/7QoWo6p+QBQr/gOye+XO2XOuPAB/M\n+/r5gug/A+wls+7/DPjQpbLTKxl5gPofAb9CRnKLwO8A7ycbBF8I/5DMOj5JFov5A+D38jIfIAtq\nPg58k0zTvtL6PAX893l5C2T3xdxLuaYXKf8IWczoIeAZ4G8u2eVjwB2SZev96dU674tBtl88so1t\nbGMbr2x8v1v029jGNrbxisc20W9jG9vYxisc20S/jW1sYxuvcGwT/Ta2sY1tvMKxTfTb2MY2tvEK\nx3UxM/b48eN67tw53vzmN/PFL36R17/+9ZRKV76mTxzHfOYzn+Gmm27itttuQ1U5duwYURRxyy23\ncPToUVqtFseOHePQoUOMjv6tJW2uOh544AEmJyfZv38/jUbjxQ+4Cjh16hTT09Mvqe2eB1dt6dg/\n/eqfq2f3EiQXWFvYxJvciVtqQWKIrUUBg6AoaiHFAkLFMzhxSt9NWFzrkoRCtVzCV2Uj6RGsrbF7\nxzgjE6Ok6mFTSxRHIILxhXK5glhQK0TWIvl/AGoERxxcwDjgFNsBVSgm0SqCqqJYRAQj0E8j5k9e\nwG34VBKoNMOsrbWCSYRYs/KixBKhhIlFLFhRjJJdZ5HlNvgOVrN/lbwSeTeoKAaIbIpRB8dxsKpY\nFEmLmTeK2mwbVlERQLHWZtdQgl2ddd737r9/VZcE/ie//UXtbi4xuXs/F848zeTENKbWApHBDaRS\nXMngklBAFNIkZP6przI2NkN9z6sQTemdO0aURtRnbqa71sbzoL+6QH16D151JO/Dov0MiA62CIJI\n3nyiGBUUyU6meSVk6z4QFCuCaNEZyvKxJxmRTfzpA2h1YlB/B5PdGMJW/xXXNpS0WHxUHf62tbvN\nz4tu/brZ3qBareEaJ7sHtwoYtKFqdvxWhqTm/wmiyq+96+4X7dvrgugbjQb3338/nU6H2dmXll7e\n6XR45plneNOb3sT4+PhgmzGGVqvFZz7zGZaXl3nve9/LDTfcQL/f59SpU8zNzVGpVNizZw+Tk5NX\n/Zr279/Pvn37WF5e5qMf/SiTk5Pceeed3HTTTZTL5W+XjJ8XzWYT170uuhSAjbWdrDo+fjLBqeVN\ndrdm6Zd8nNTDkGBFUGuyhUbUYq2iRqhEylIS0j99npGpKcr1Oht9IQj7SNiG0UmeCmJ6j25yy6GD\neI5DLMr5
dgDrlpFWCa/k4ajguOQPfPZQigjGZAOMw+D5LZZQYIs+Bavg2OyBslYJSeiNNRhplNA0\n4rOf+xKmVmH2VTdQrY/
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdcec09cbd0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Inspect sample 3: the input reshaped to 2-D (m, n), its target T[i] and the prediction Y[i].\n",
"i = 3\n",
"x, y = X[i].reshape(m, n), Y[i]\n",
"sub_plot(x, T[i], y)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_4\n",
"block1_conv1\n",
"block1_conv2\n",
"block1_pool\n",
"block2_conv1\n",
"block2_conv2\n",
"block2_pool\n",
"block3_conv1\n",
"block3_conv2\n",
"block3_conv3\n",
"block3_pool\n",
"block4_conv1\n",
"block4_conv2\n",
"block4_conv3\n",
"block4_pool\n",
"block5_conv1\n",
"block5_conv2\n",
"block5_conv3\n",
"block5_pool\n"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# NOTE(review): 'buildiing' looks like a typo -- confirm the file name on disk.\n",
"SKETCH_FILE = 'buildiing_10.jpg'\n",
"# Flag 0 loads the image as single-channel grayscale.\n",
"sketch = cv.imread(SKETCH_FILE, 0)\n",
"if sketch is None:\n",
"    # cv.imread signals a missing/unreadable file by returning None rather than raising.\n",
"    raise IOError('Could not read sketch image: ' + SKETCH_FILE)\n",
"sketch = imresize(sketch, sketch_dim)  # resize to (m, n)\n",
"sketch = sketch / 255.  # scale uint8 pixel values to [0, 1]\n",
"sketch = sketch.reshape(1, m, n, 1)  # add batch and channel axes for model.predict"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVkAAAD8CAYAAADdVNcyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvVeQJEl65/dzj0idlaVVl2g13T1a78zOzC6wwAILcWcE\nAULSeDgjabf3QprRjA+E3RNp98KHO9JoRiONS2V3diTveHYAFjgCewB2sbtYNbqne6a1rq7u0jpl\nRPjHh4jMjIiMzMpqsdsA6rPuyohw9889XPz971+4UCLCoRzKoRzKoTwe0T/pBBzKoRzKofxNlkOQ\nPZRDOZRDeYxyCLKHciiHciiPUQ5B9lAO5VAO5THKIcgeyqEcyqE8RjkE2UM5lEM5lMcojw1klVK/\nqJS6rJS6ppT6vccVz6EcyqEcypMs6nHMk1VKWcAV4OeBu8D7wO+IyIVHHtmhHMqhHMoTLI+Lyb4B\nXBORGyLSAP4l8CuPKa5DOZRDOZQnVuzHpHcGWAjd3wXe7OZ5ZHRE5mdnETQgKMD/E5aOB32K6uv2\nsa97EwEVilz6i7XpQ3XxrRL8EkQlEYfgTnXm48Fytg/fD1BUie/xqGW/DDyUZPmxFsgDRCkdFweP\nOyGohP42258ohRLDwt27bG5s9VV7HhfI7itKqa8CXwWYmZnlj775fbYlTVoMWRFEKUyAFE1cCAOi\nCh6q6J8QhqgOsFYhgFHNXOuSTdKMu/2ksxwk6VKa/9px0cbYqHUmFEqidUUSULXDtCOAkpDO4IUE\ngv4KFdIVzg+Fnx+C8t0JDWvC+a1US13khWIZl9gxNvX4BZbk1B/INcspsV0ICT+H8pikVxtIfNwn\nAIartl+tO9ubBGQhqPYhre2byD1+m1G09YtI0A5bDwIOEgpjBDQYEbSRFt7URNC2hV01/M6v/VTP\n9wnL4wLZRWAudD8bPGuJiHwN+BrAiy+9KmU3hc5pyljgCKLAqBZUJhEw/3ngGG+0KuLedmyCayJw\nh8JIcNF2Vx3tvNX2A31GQAUlbcLgLslVLF7/kuqjxBzDzNbEGKtfoUIgSzNxvl8AKwBNwS98UUGH\n0gwS0tkVOCUM0u2MiOdfM4/jfVnSICX+rFnnVZDG5m+HBA8T3ZpeQmnqJoeENln2G0kldXy+33il\nVh3+EtV1Y5QSg8+Y+pZr/HnQJqXJP0Jg2woSEJVmCk0LZEGLIAYsS1EFjK0p7tWRXhUuJo8LZN8H\nTimljuOD628D/2E3zwrBUlATRdbg9x4aRBRK+T2Rimd+CICRkIkhwtgCHy1QlTYghBpx4BICiRAC\nh3rPsP9wmGZ56WY6Y2lV4QoSorORxh+uHKFrTbgiSSQfrBj7bcKnEoUIGMAKKokdAK2lwQRdhh+3\nQgdatLTTo0N6I51YKP5m5VWhFhgB03gDi9XLcJl1laSMJ94YJTEvH5/8WIwbT7ZIqGPbj812c0so\nrCQQDdc9CZV1+1nTX3tU19YTYr6t9hIa/YnfHpB2fbVFMMZ/P238NqdcwdIa7RlS2WYq+pPHArIi\n4iql/jPg3wEW8H+IyGc9QpB1Pby0QhtIixeARJumqTiVjTd+WgaDlnvEdBAqmVbjDwF1lG35TBoT\nBeGWtqD3a8VDqKCDCxPSFRnaSAxXQgUfVJNE5ttROYgCfHN4pWizXABlQpELKKNanpr5JVYzEh15\n3zbQtkcTMWxrpUQ1qUJCB68SKmQHaLfob1x38n1iFY83sG4eoil4APlbCqyE6jhRhtnNb1NUpJ53\n8ZuAsGEgVwimlQYh5juB7Ybuw/7DCCsEpjpp+dECHgEpEYMyYLQiLRrHKGzTSKzT3eSx2WRF5E+A\nP+k7gFagAvtgwMQI212lswGr+INWI+0yToWOBxK7VnGvsXapYnF0gEWs0jXNDGHl0kbWhHQlFJ+E\nXOM9fKBciZ93YSBvAr2ICvokhRhA+/burIBY
0vJnqWaFU20GnQh67Zxv247btt1Q3xbLEtWhrifU\n9WC5qguX6F31D40CB5F4l9Sq4uGbWM/brubxdhDW1BlRM4w0/SQ1AlEdZr5mhP5wX7VssL5TMFZr\n1k9pag9qj7RHlSrygoFOBaB9AiMKlA5MZAeblPUT+/AVFyGwgUDw0YuWrTBeLE1WG8fUxGGpij3o\n1s6S7AYE2N7M9Fh6W7gf7nEDj+EKKYHBJwxIkThCFxJ7rgREN5lAKFzsw0BTdzO9AqgAUC0lOMqv\nWF5KYbTfa++gSIlqsQJL+aYDE3q3ZkWNGMXbdbM5oIj1LR29YTesjmZ357Cka3kl2sTCnVg4sh5g\nnSyd6e+vV4jJE0F4H00ikjhBqO+MkogmqMX8J2qLtK3OzA0b6pokIl5nWm0x+GDeChsrf1HR534d\n99tFuG0pxGeyKjCtKWnFIa2xYv8V4YkBWVTwkqGuKgIYcXNBiKw2GVKI/HZIuGDC1x2Nv4PKRlw7\nnUJpjTjFgLqlJRwuxkqbLLQJrq0K1KwgQstM0fSvBTzlA6pW4GqFVv4zbKEhIEawAM8YChakHIMR\ng60Utp1iV/n2Ag+FZRRYTWYQfIkNJTqcT5H8bNU7lTyASMq30Hsk+Y+XWaKeCJOJNdawATsh37tW\nlj7S/djCPHJ5wEQkZH4I7hKjiWRrvFAjOkKKw+0kgTq3QVNaM2Ui3sLgGQQKg6uEakWY40pQaZrA\n6be5UMAQeEtAMiIk5wDy5IAsIeBr4i0+cPjlFS+BTjzswMcubSyxvfWTsGY8CXqboJjUz7U+9Mfj\njSNIGBua7xN5L+lQIkDKCxIVfHkzyuAYUBiyHmhlUJ7Bq9XY24Cl1atcubSCRvHalz7H+PgoAJ62\naXiQQaGVoEzTzNDuAMLDNQXRoWPrZVqFFn/aLUsTZT8/4RFAtNMM8Z/Y6KCLhj5S87dMOrJB9s2Z\nlrvEfnvpiBGN6HPp0KGC3jRSNwLyocLhwjpDz5rXKuIr0BMiMO3/0sH/DipPFMhGpdlciPy2XUKZ\nqGJF1GzjIQYZNukmMd6EThRU+KOVihSKkFiHOp4JwVS0cGPvUrEkFjKCX0JrGBYO7WmwlVDBp7O2\nMWSArHGo7dUo16ss319lcW2N5VsLbFcrfHzlKsuXr+Pli/yG1+Dv/Nw7AMxOTlC1Vcs0oXSbXTQr\nWQtUA/bYmt8bBrp4L9NMcawM4hIP0pPJxgI12X9H/ndJR7c4/zZLyCIa3AfSpf42L5P8h/taIPIt\nISKxTjquI/rhK4HJQsdIsXM6eRxOk5+3PyLHdMbiO2ileUJAVkVYoC/+m0bmbLZ9R6XZwkItrjUf\nFlqtz9evOgo2koxmRgeF1eoJVWcv3AVLOlisP10qsOuErlv+g/Ro8QG5OaUq/Kw5XcuI/5nKDVhr\nQwRHGdIGsuJR2atzd3GZKzfuc+PGPdYWL3NjdQsZ8RhOz/Dq63PMuyfYWlmmvLnJZ9/6EdPjBQD2\nnn2BiZFR8pkMriekUxY6zApDdLHNslXEzXTJk+b7hPM54Vtm13zsxmRbekKdQWxg0NlQ4h46tfZI\n1Y9bunYvj1ziTFTFMjA+jbJru2xWX+nMyV5MNqqj7RCOJzLKa4UJTetqEZGwNTfa4DuLv7ngoO2h\nXdUD84OWaBoOIIdbHR7KoRzKoTxGeUKYbJSG+0PTpDmwXcKGAieFUQluiTq79FDtDk7RdX5chw0j\n9Lw1jqH1os1eVyGt1WH+tKmYgV8p/yMWQj1Y6udrEDICJTyccp3yzibvX7nNB+9d58LlD8jkSqTG\nB3nqzAl+6c0xjj53lKzncGRkmBtn7lO7t825nQ9QVpoP3v0QgI8/OMsXP/8273zhLVIpG4OgDWgr\najZIysEH4Vv98MVug45W+ISBSXxqX3v0GhpBxMZ9TwJ3TX7PJyFlTYmmMMm808sd2h9F4xPwwn7D\nRotwe2u3
4tioUtp+JeSjufqyqa89l9wfprbbc5Syh22wGv/7hBGCKWD7DL8S5IkBWV/aq5DC797K\n3K4vF2vsLXQNDRTaLa2
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdd186ca910>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"result, _ = model.predict(sketch)  # second output is not used here\n",
"# Display via the imported pyplot module; a bare imshow is not among the notebook imports.\n",
"plt.imshow(result[0])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAACeCAYAAAA8AsGwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd8HMXd/9+ze1XSqRdLlixZ7sY2BleM6ZgAxkBoT2Ie\nAiSQEEISeCA8gSehhJA8IQQSUggEfhAeEjuFEMBgXACbYrCxcQd3yUVWsbp0fXfn98fd7u2dTrJI\naDb38et8u7MzszOj28985zPfmRVSSjLIIIMMMjh6oXzaBcgggwwyyODjRYboM8gggwyOcmSIPoMM\nMsjgKEeG6DPIIIMMjnJkiD6DDDLI4ChHhugzyCCDDI5yZIj+MwAhxFVCiDc/gfusEEJc83HfJ4MM\nMvhsIUP0nyCEELOFEKuEEF1CiHYhxFtCiGn/Zp5SCDHyoypjBhl83BBC1MR/t474+WIhxJX/Qj7D\nhBC9Qgj1oy/l0YUM0X9CEELkAouAXwOFwFDgbiD8aZYrgwz6gxCiXggRjJNpsxDiSSFEzkd9Hynl\nOVLKPw6yPGfa0u2TUuZIKfWPukxHGzJE/8lhNICUcoGUUpdSBqWUS6WUm1IjCiF+LoR4UwiRFz//\nqhDiAyFEhxBiiRCiOh7+ejzJxvjD+B/x8AuEEBuEEN1CiN1CiLNt2VfHRxI9QoilQojij7faGRzh\nmCelzAGOB6YCP7BfFDFkeOQzjswf6JPDDkAXQvxRCHGOEKIgNYIQQhFC/AGYBJwlpewSQlwA3A5c\nBJQAbwALAKSUJ8eTHhu3bP4ihJgOPAV8D8gHTgbqbbeZD1wNlAIu4JaPvqoZHG2QUjYAi4EJ8bme\ne4UQbwEBoFYIkSeEeFwI0SiEaBBC/NiUVIQQqhDifiFEqxBiDzDXnnfq3JEQ4tq4YdMjhHhfCHG8\nEOL/gGHAC3Gj5tY0ElCFEOL5uCy6SwhxrS3Pu4QQfxVCPBXPd6sQYurH3nCfEWSI/hOClLIbmA1I\n4A/AofiPsiwexUmMwAuJWVGBePh1wE+llB9IKTXgJ8Bk06pPg68B/09KuUxKaUgpG6SU22zXn5BS\n7pBSBoG/ApM/0opmcFRCCFEFnAusjwddAXwd8AF7gScBDRgJHAecBZjkfS1wXjx8KnDJAPe5FLgL\n+AqQC5wPtEkprwD2ER9hSCnvS5N8IXAAqIjf4ydCiNNt18+Px8kHngd+M9j6H+nIEP0niDhZXyWl\nrAQmEPtB/jJ+eSRwAXC3lDJiS1YN/EoI0SmE6ATaAUFM40+HKmD3AMVosh0HgI9cc83gqMI/47+7\nN4GVxAwNgCellFvjxkchsU7gRimlX0rZAjwIfCke9zLgl1LK/VLKduCnA9zvGuA+KeW7MoZdUsq9\nhytkvCM6EfhvKWVISrkBeIxYh2HiTSnlS3FN//+AYwfZBkc8HJ92AT6vkFJuE0I8CXwDWAJ8APwW\nWCyEOF1KuT0edT9wr5TyT4PMej8w4qMubwafW1wopVxuDxBCQOx3ZqKa2Ii0MX4NYkakGaciJf5A\nxH04Q6U/VADtUsqelPvY5ZlUI8cjhHDEO6ujGhmL/hOCEGKsEOJmIURl/LwK+DLwjhlHSrmAmB6/\nXAhhkvXvgduEEMfE0+XFh7cmmoFa2/njwNVCiDPimv9QIcTYj69mGXxOYd/ffD8x77FiKWV+/JMr\npTwmfr2RGIGbGDZAvgMZKgPtqX4QKBRC+FLu0zBAms8NMkT/yaEHmAGsFkL4iRH8FuBme6S4m9mP\ngFeFEDVSymeBnwELhRDd8TTn2JLcBfwxLu1cJqVcQ2yy9UGgi9hwuz89P4MM/m1IKRuBpcAvhBC5\ncQNjhBDilHiUvwLfEUJUxp0Qvj9Ado8BtwghpsQ9ekba5qNSjRp7GfYDq4CfCiE8QohJxOarnv4I\nqnjkQ0qZ+WQ+mU/m0+dDzFvrzDThK4BrUsLy
gIeJTYZ2EZu0/VL8moOY4dEG1AHfImadO9LlR8wB\nYTvQS8ywOS4efgGxCdlOYt5iNSn5VBJbq9JOTP65zpbnXcDTtvOktEf7R8QrnUEGGWSQwVGKjHST\nQQYZZHCUI0P0GWSQQQZHOTJEn0EGGWRwlCND9BlkkEEGRzk+KwumMjPCGXzcEIePkkEGRyc+K0Sf\nQQZHJfY19cgerwdHFHQFBAJz8ajAWmUKwjyPnSSO4+HErCGzt5LxD9IMlxiAkAmryX4d61haYTKe\nYcz9TsTG9zKWjypE/DtWRuv+ZgFE7F5mgBVs605Tw9I6+Mm+Vl66aJ+lXrpP+aT9K7lC9rj2+ksp\nE38nM0k8TEr7sZlrLJKBBEOiCEGvQ6G6u53KEeWHbZ4M0WeQwccIlwyjGC48uo4u46QoUr8BqwOI\nfccPMUkfEkQviRO6kAgJRpx0Y0Qvk+JJi4RkEiFJKWNx4pEUwDBAGLFIJvELBFJRECKWLqb1iiTm\nFYlK2M5tUewVsEGmnMi0Vz9LFJ8M6++QruSJ//qEJTrg+N/KJH37t0nwsT9QrDOQRiy+YSARuKSK\nJkKDKmuG6DPI4GOEFAoGAomCkWQZ28x1Oysmhce+7ZY8EuL9RSy/VAue2HWL4IWZRiRHEgkSQYqE\ndQ/oAnQ11oGoUsFjJhNgkGzBS9OyJ94l2XjZTNOH8M3rdn4Uts5IkKhkEmxtk9Qo9kzThH1oDD4T\nKUDa/0gyUW97dlaYFFbHIBHWkWG7r7R9y6QzBQOJghL72wkVJW079UVmMvZfgJQSwzAwDIOenh5e\neOEFotEomcVnGaSDEHH+jn/sxwnzvZ+02C+LJK6z5B1bvmZ2wpYwiQvMuDF6QjEEigoRp4FfMdDQ\nUPQQoZZOlv/9Rdrbmok4QCopI434vRTbfU1YBD9w1ZLS2IuXdJDaTiI1Yj8N9m99xMCfdDfvE3QY\nLhDxOFYvKeNZyHgZ4iMrW5sLIa3bi6Q/9OGRsegHATuBa5rGq6++yo4dO8jLy2POnDnk5uZyyy23\ncOedd1JQUJCw1jL43EPYpBe7FZr4haSapsmShWlpm1KNGT3JGJa2FELE7EQzqgCMmHWuIBFGLJ4u\nJZoKEUXHZ+i0bNvHjj2NrHnjNXbrbZwxaybP/m0Rby57g+vvvpGa8gqiusTliFv0FonHOgCDFM4Z\npPJiGudJAWnRnwn/r+BfNMg+kiJ8uHtbun+y2p88UhoEPitbIHwmCtEfIpEIS5YsYdGiRVxxxRXM\nmDHDInNN0/j73/9OcXEx2dnZLF26lHvuuedTLnEGafCp9L5NzZ2yw5NNVthAU00jzGajJ03GmtRJ\n4kE2L8fll9jAX8R0XZH4tgvw1sMk45N+EpQ4wUdVSQCdPN2gvbWVFxet5pWV/+DkOWczdfIIhlUM\npb2zF607yvLX3+T97fUYbe3cct8PGV5ThWGoKIqwymdq90lIIcTDNXyKnP3ZJoNUpGr0prSWMgFh\nr2PSxCsgDBlPIxOSW4pebx4bgDBiAlrAoVIabKNq+LDD/rYzFn0a2DY+orOzk/vvv59AIMC9995L\nfn4+QggikQg333wzEyZM4MILL+T+++/nsssu45hjjjlM7hl8nmAfYadq2NZkrHlO4iD5ybUFSJHI\ny7wBtnOLTWKMr0iIqKCpEhWJ0dPDy8+9yd9eeJqK4ZOYe+mJfP/4m9myt5nW+ggvLvsLb722Al3P\n4qGf3Up5kY+NG95HC3cT9gdQ3Dk4lLguTTKnW8cD6jVmp5UoqqDvua0myW0zSHySnUVSeW0jttTy\npytTn84yKWKiQazBnCnbfMjhRcaiT4PGxkZ+9atfoWkaV155JePGjUMIwbJly/jFL37B/fffz+7d\nu5k7dy6LFy9m7tzYKzB//etf09rayr333puRbz57+FT+IM0tnbLDnU1WxECLW8LWxJid6OP/JVnx\n9jhxcrR3
CvZJWLs8JKREFxBEokmD3FCEzq5ufvWHp2nvDTHvqrOoDLj5zaPP4M41OLh7IyOPOxmt\n9QBf/NIFVA4ZyrtvrCL
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdd1801b450>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Side-by-side view: input sketch (grayscale) and predicted image.\n",
"fig = plt.figure()\n",
"panels = [(sketch[0].reshape(m, n), 'Sketch', 'gray'),\n",
"          (result[0], 'Prediction', None)]\n",
"for panel_idx, (panel_img, panel_title, panel_cmap) in enumerate(panels, start=1):\n",
"    a = fig.add_subplot(1, 2, panel_idx)\n",
"    imgplot = plt.imshow(panel_img, cmap=panel_cmap)\n",
"    a.set_title(panel_title)\n",
"    plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAD8CAYAAADzEfagAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvfmPJVd2JvbdiHjx9sx8uWfWzn1rNrub7H3RaJfGtgz/\nYNiGjTFgoH8yMIZtWBr/BQIMGIZ/FGwDAmzAMDwGRpqRpWm1lp5uimQ3m2QXySJZW9aalfvy9iXi\n+ofznRvLyyKrms1WjRUHKLx6L2O5cSPi3u+e853vGGstCiussMIKezDz/r4bUFhhhRX2b5MVg2Zh\nhRVW2ENYMWgWVlhhhT2EFYNmYYUVVthDWDFoFlZYYYU9hBWDZmGFFVbYQ1gxaBZWWGGFPYR9ZoOm\nMea3jTEfGmOuGGP+4LM6T2GFFVbYL9PMZ0FuN8b4AD4C8BsAbgP4MYD/2Fr7/i/8ZIUVVlhhv0QL\nPqPjfhnAFWvtNQAwxvyfAH4PwImDZmu2YU+tLMCY5Ddz0oafud3nrA/TGKsfFjohxfy0cSzfY/ke\nxTEmkwgAcNQdAACGo4gH8tzhYu5nbZw5h/sPO84zHozny/89Ntp4mW1gpW1yvNxxUl8tsn/TbQ2/\nG8/jOY3bVuZKoFwpsw36vQovCOXaI7m+Yb8LABiP5LqjycS1y3W3/ocniN3l21TjtcO1j7Shgetn\na+WcQRCw7T70orRbDPsp8OWzVpFtm7USAKDkm1Rf8L66+4LM75nfeFw9jydPRubyjN6/VJ+637V9\n2atF9qHMA5/8A/tzAiN3Yz/ltp8Kl306UPcge+t9e+/y7V1r7dInbf9ZDZqnANxKfb8N4CvpDYwx\n3wXwXQBYX57HP/+ffx9hGKT+njvi1M2wJ/yoD2Pu9+R5T72QJz1oHz9opg5z3410gIyjCUaTCQBg\nMBwCAPo9+ezxe7vbx87+MQDgz/9O5pMrt2UwsaYCABjHMUaDnvx/JPtFHHj0Zge+9Fu5UkOl2gAA\nhJUqAMAPZQADt4ksEEXJoH3S8Yy1iDkIxRG3iSP+Tb5Xa9K+crmCCbcpleXcF558GgBQb9YBAI89\n8xKaC2cAAN2jQwDA1YtvAADuXr8EADg62kMUj6WpbvDgucfSrsEIbK9FNJEvlvtg3OM1yPe4NI9O\nl9c1knO2FucBAJXKDLgzAk4uYUkG9blZafOXnlsBAHznRXmH1lol2EiOPR7L52A44nc5z2g8caPl\nOJbjRn6D1yS/V70hbDzh9ck2pZLcm0a1JttUy+53nfymBk2bHk3zk979B0173wF2+sl284873ElP\nf24CfpBBc2rST/47dQY7vU3S5OzW9oQD2KkXP72v/Dgcyf149nf+6xsntH7KPqtB8xPNWvtHAP4I\nAF58+rytVEKEoe/+PoU2cl8zQ+bUCMufTxzhcp2YOdHHT6kGFhxv3CAyHskLNBpLxw9H8iKNhiM3\nKHlEG42qDGQETLh4bRPvXLkLALi9K4Nlpycv/2TSZXNtgtjK8lKVyhV+8uUK5btfCgEjtzS2cs6J\nPtSxol75J4fOdQIHxGhi3d88T87hh4qYxjynDDLG8+EZ+f/amfMAgNZ8CwCwevqsfJ55DMaXttcq\nNXbgF2R/DnL+LQ9H7X1pow7Q7LeSF2WaORxZwEqfRIriS3Jcz8p9MPEQlVBQYn8sxxn2pE/LFRkY\ng1IpuWYOvgdHRwCAN34m9+H4WAbc58/P4NyS9EXDTe6yb+BLwwKvBN+Xdo1i+dydNOX6Atl2pdVE\nrSx/CwL5dJNEboCU5zqLSqcGTUwPhNPDoHVAYnrQzFlq5EkGwvuAEZy07QntSK2+TtpX/nu/gd66\nKzhxv9wu93c36vVPH8O7z/hxP/usAkF3AJxJfT/N3worrLDC/q22zwpp/hjAk8aYC5DB8j8C8J98\n7B4mPUl9nC8n9fsnTRBu59SsnJtVMufM+ax0
2arL7NFohH5fkKQiSkVwJS7vwpKgm9nmDMIy/WH0\npSmiKB13AACd7hDvvCcrgqMOEVcgaKbKpVoprLoltvoF4cnxbMrvCQBjJDBSryHxv0W8phgRl9Pq\nIjUOyQoCq7Zm0WzJsnRufhFAgnzv3nhP9uVSNZrEaC3ItqtrywCA+QVBmmcvyDK92phz98ASpfmn\nz8u5HTKcwN6RBh1xCa9tdv1WUqemASz7IOfvNJ70kRdNUA2IkCK5D6OBuEIGXTletdZ0/sRInxEe\ncDiQAx4cCPrduD2Dl59dBQC8/KT0ycocl9MVOX4YhihzqT2B9Kk9lONWZRO0Zn0EClXyy9Q8WjM2\ntRrI/dFkn9zMYRwy1EvyprfNWeb9yr1702Yc+lQEeHIrcz7q+ywbPx778vjppaVNHz3XrFw78lud\ndE3mIZHmZzJoWmsnxpj/EsBfAPAB/G/W2vc+cceTGm+yNyX9gCQ36GSHShKgSB0OukyVzwnX2+NJ\nhD59VMOhBCdGXHrrwOMbDxUuS1uzswCASlm+l7gU1BfcpFb7iVtGfpidET/Xs0+cx9kP9gAAu20Z\nICYK/n1ZckfWYMIBYcL9vbzfhn+PETkfZBzlHmoOjGG5gWpdfHrN2QUAwNySDHYLS2tybUvLmJmT\ngS+eiB/14pt/Lc3iUnTM44eVOtbPnOZ1yaB75vwTAICZOTmuF4TwjQ7eDLiwn/xzj8vv4yEGnIg6\nXTnnoCdLZVPittynXI5hdVBif0UMqIGDMEwMfQLCMv3N7Mju8Ta3tQhDaTO4rDZWnzeT6dvN3Qle\nfZft8WTJ/a2XZPBcanLwDP3UICz7LTa5dPfpf/bs1IvrbmceKNjpAedjQzzuJcmORia1zJ86gs0d\nN+3Tyg2MJ1neRZa8nzZ59tx16fGye38sAHITQfqk+UlBj2Ozrt6TjnNi4x/OPjOfprX2zwD82Wd1\n/MIKK6ywvw/7ewsETVsKmplpcD01KXwMpHZLU+49jqJUxFM+h4xETyZKG/Hg+YIWyyVZDteInBya\nLPnwfaXw6Ed2uXQy/M+266gniGpjpwOvIojVJ/VmFMkt8UIJGk36fRd00t6IHHVJ0RVpRkGIMqPm\n1YYcd7YlaGh2XpbQrcVlt+Ruzs4BAOpNQb7VKgNMYQlghPf9i68DALZuX5NzEtER0GJtdQ0L83Kc\n5bVTAICF1XMAgIAui8D3EpSRY0B5vvTxqQtPYX9PUPe1y5cBAN22LNOrdbkmpQx5JoJxaEOR9Tjz\naQDA45qYwbGQfR1Fcm2TSYQgTII4AOB56koJMn1sjY+9Y3lm/vqt2wCAzf0+AOArz0uk/fOPz2Nl\nrsLr4opDEax7pqOpCKVDtTaP7NIIUX/KBV6Sq53+ZtLw6j7L/By8zZ5Rl8YnBXA+Aaal0e1J6+iT\n/3DCNulz3m+3dIvvv3QHcl3yyX6IE61IoyyssMIKewh7tJDmAzhkE/+PTUjiRA5K++kPlCie8Oh0\nhnGcRqLHpvINK2X4JSVlJz6SzGeKbJynNugMd/IcKt9GhGfvb+wAAN6+vI/jvvw2igXpzCwIaplM\nBDFNxhOM1akJJYuLD63eoG+SaHJheR3zS7L/7JxwEpv0vdYaiprLCNX/SjSkfkr97hlg6+49AMDV\nSz8FAPQ7EkQZ0e9bqYlfb2lxAbOz0o6VU48BAIKyUnqkvSXPuIBZ5Hyu9B/35V6NJyM0WuJjrdTl\n+nbvCbIbjzqZ6/UDHx4Rpm+U1iU9FBNVGuMnQCIQpFpSfqb6tcdDGE8J73wV6De1nuIJRS7Godgj\n+p9/dlnOfXtX6Ekf3erg6y8Ion/mnKDvvlwCykGStJA8IzmEeOLjn/VP5v2EGYrQfYBTPviZPt40\n2DoprnBSu/Q496P/mPviyTx1SSiTnxQOMinOZR5ZT0c3pi4wR5+SYOJ9GvgJViDNwgorrLCHsEcL\nabpZIfGH
xFapMqmsCwj9RzMyJqQEaeRY09FqFUGRs40qykRXSij26GtKZi3PpQa61EOGTk1m0spH\nBPOWeEusuwb5y+a+IKZ
"text/plain": [
"<matplotlib.figure.Figure at 0x7efcca448050>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAD8CAYAAADzEfagAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvWmvbVmWHTTWbk5729fGiy4jsiqyqcrqUllll9MYg20B\nAly2BAZjo0JCqk9IRgLhtH9BSUgIPlKysQqBZCEaGWELlC5scJVx2dlVdhEZ7XsRr3+3v6fd3eLD\nHGPuc05EZLwbkZHxsmrPkOK8c+5u1lp777XHmnPMMUOMEZ111llnnT2eJZ90AzrrrLPOfpKsmzQ7\n66yzzi5g3aTZWWeddXYB6ybNzjrrrLMLWDdpdtZZZ51dwLpJs7POOuvsAtZNmp111llnF7CPbdIM\nIfzrIYQfhBBeDyF85eM6T2edddbZj9PCx0FuDyGkAF4F8OcA3AbwLwD85Rjj93/kJ+uss846+zFa\n9jEd91cAvB5jfBMAQgh/F8CvAXjPSfPKlSvxhReeX/klvM+//yiZXmY/rP+P8cKL+mi3LcsCAHB4\neAAAmM1mAIC6brhtAPgybRr+xu8hWHuSxBYpSZoiSdO1bSL3UctDkiDhft4KbetfY7v/5oucX+u6\n8s8PetmHJEGSpOuH2dgnxuj9kam/MvVzdd+gnm1cmrVteNzN4wMRgcdMNYbJ+oJPY4X3GJN3XYc0\nQRLYRuiabXzGiCzPAQA5P7X/cq5rX3lb+v0eAGB7a8uO29S2D9s5n80wGNvf9vcu83g/+Z6+r3/9\n6wcxxqsftN3HNWk+A+Cdle+3Afyx1Q1CCL8B4DcA4Pnnn8PXvva7WHnM0HoOdONvXpSf/Mk0ou1F\nO3lwkuKNuvngr229+SDhPaZRHq/RZ13j4cO7AIC/89v/LQDgW9/8DgDg+GxquyQ56rIEAMxm9lu5\ntIk2S+2BGu+MAQBbu7sYjOzfdW1tLhdz25ady/sD9PsDbtO2A2gnqbqpUCyWAICirNbGRP08Oz0C\nAJwcPULJh9xt43bo9ccYDu3BTjmpl5X1qa5q/+xxgojcfz6ZYPWHwXAIAKjqxq9JmtnEk2S8Jzk5\nVU2FwEanqW2T9uwR87kz1ugNOCltW/u2RvaZ8Br1M76Eqgp1bW0ul/xkvzUJjkZbGAxsbBv+bb60\ncZzM7JotihpXnr4OAHjq6Rtsn7Xr5ve+DQA4OXoIABiPhnjxpz4FAPhTf/LL1q+FjUm/Z2PxnW99\nA5//5V8FAPzFv/jr/NsYP3m2/rSEkNx6nL0+rknzAy3G+FsAfgsAvvSlL0a761cnzU9iUoxrH5u/\nR0R/gN9vG9+2aR8y/baJtmKMPqm125rVlT0AIU2RvgsxCQ1p0mz/pkk2bKAPTaxNVeHtt18HANy9\nfQcAUBQlt233DZxo0l5vrc393B7Q8fYOAGC0tYMktXNlid1OfT6QSbQ+zIoSSW79y3r2sPeyvh2f\n7Y0hOEKtCiFKISaOTbTP2eQMod7oH4ekaawv/X4fWzvbdk6fNDlRc+IuFguExK5JxfHuD4Zr50w5\nOdVx6eOsyalR+yo7XpJlSHJrV8O+h4Zjw4kwJDmqyvY7PDwFADx6eGz7cxyHQxvjrdEIo6GNU393\nBAAYs72anGMTsSyWPHZgH2z/IV9mTVUBlW1zdOstAMCysgm15otE164ua5ydnwMADk6sXTlPluTW\nlizPcHpmbS/5cuW75yfMNiHL49nHhanvAHhu5fuz/K2zzjrr7CfaPi6k+S8AvBRCeBE2Wf77AP6D\nH75Lu1i1N3qz8edm5W98N7zL9yVUtY7eYtO0fqJV1AgA8tk1jb+pW5NfCr7NentbRCd/j15eTRN9\n2Sb/USQCC0KOaepd8N3lLtOVCUCSrnx5D1v3m21sw68p/1Ev53j7zj0AwIzL4UqNINIpqxIVEURF\n/6eWc3tXzIc15tIyTVLs7e+vNd6Xv1zWTQ8O
gWjHLrjMX8yJ+laRsJbzqfxuds4047l3du0Y58eY\nFQtrX7BrUhby39m+w/EQO7u7a0MhhClXRTHo+dgJaQo5lUS7uq4hRPdl6t6p6EaoIl0XeeZ9WNJF\nkRTWz0roO039PmuInEseRz7OYllxjAqkmfyedtwsJZqku2M0GGLA1UCu+5TjX9Q2RnVZOWLu9Q0t\nDnND1CO6H3Z3beUQywIE6/jBy6/ZcXNr+/6lPQDAyWSBxds3AQDf+u7XAADPPftpAMClfbs/Rv1R\n61qSf/cPgUsN+JgmzRhjFUL4TwD8XzCn5H8XY/ze+23fVBXOjx75MifwP/uiiUsO+fZnLVE0qWhJ\niRVHOmBLmOB+KE1c7HrSPgjukNdEGNbbwC9rf2vDNevBgfQxbpAftoUW5E1Tv7s9H8GOjx/i9Vdf\nAQAs5vIhyq9o28yn5z65aQz7O3zIxrZM1JJwd2eI4WjE46xPtOrg1vYudvgwadnrkxSXh8v5HPOF\nBSUyrvVGw/HaPs89+xkAwL/ypS/gzn3zy75111znWWPX9/DUlo2LJMP2nk3mWnrKj1pxabpYzFtX\niZbsCjbJr+o+4SGMFAJ/GbYvAAZTmojFXD7gKYdA1872TdIU4L8bjm3GezBhv2uO7aKu3R/bOBDg\nZB7shZTlmQd3cvpP+3SB9Pk9TzIMMjuXgoBJ2Aj08d5K8xQZX1qTqbUjTewaLZZ80ZXA5J75l//7\nv/VbAIARJ92f/+IvAQC+9AtfxKeeMd+ormOayL+r51SAIwWw+aw9uRPsx+bTjDH+AwD/4OM6fmed\nddbZJ2GfWCBo1UISkOcDINdSJGvRowc2Nt5OIbzrrfQuJPYjQGYfZB/3GUJI1uH1hzQhprfeeh33\n7j0AAMwXhiAaP7wdv9cboCbq11J7OBy17QGQcrk96A2Q8VplRPO+jOVyP8lqXyorIOX0ISLO5WCO\nsjK00mcAQwhKroJ9ugY+95nncPmG/Xt4yYI9X/jpzwEA3nnnNgDgH33tW47SKrlViOiyni1tt3t9\nD74Etqsi4tQ5C6K+six8nDSWGZfFfS5567oGGqLPKRFqUICPY10n0Doi4xJZx8m4CpKLoK4LlIzi\nC21rqZwKORYlFJoQclVg70wurggkChKJTrRt4zZiJD/jtW/qBkVpy/pibi4GuSiy3LYdDPuOGuWy\nOntkQaPf++r/DQD43je/js/+zM8AAH7m818AAPz0p+0a7W3ZvZBr1VetuNm0gvPnXKu/rHWFhXVX\n2I8blf7kk6s666yzzn6M9oQgzQyD7V20c3iy8u8n38fxcVoI4X2JzRexOf1ur7zyPZydnQEAlkQk\n7m8T6brXoNcYCpLTb4uBH6GsnR1SjkYjXxXIPyaElxApllXjdCEhCAV3WpSaYT4zZCP0uUnjqnnu\nsm4QMqE123+8a8hpf2EIdDDoYzzeZv/C2hgcHx/5cQdEiaIEKXjRF9VnQBTZNIiEmrX8sqJq1S3v\nM9JfOjs332pBnqWI5hGtn3I0JtIkQo9EhoOR/R6b6L5V0ZRS35Y+4aJyf2wQshchH1otVB5XnXO8\nZnNrl9A8u48sS9DjOcT/HPbtO7jqmE1nHkDVY6lYQZbYeJ2ezPGNf/4NAMDL37FwxvVnjCP6mc/9\nLADgs5+xzxvXn8GWOLVCk04lY1Qqlo6kZR5HSFapcpzS/DgbEdYfgXVIs7POOuvsAvZEIM3OHs8+\nHNK0fR4+NJrsG6+/itncfFal0uy05QpiqokMt7YMraX0hSkSvX/JEF1/NMagb8ioLJdsp45oaKjX\n72NMakstAntJug2RRFWVqGpF34milusUn1u3jSo1nU09+rtc2vEenRh6XpTKmOmhJ18h0ZXoNvIP\nFoulbyNkI/pPQd+kUGXdVMgzbuu+X2b9pObvbWIbme8fGUpL69T/ZmMbkRDWbe0QXWVihq+zPpqm\n9jFdkh6m
baYTI6CXRYFKK4WNFUNsndXIe+w7r/GCqbNFss7MiHXtoEzZUspg0u+D4dCvp7KXslrI\nl5F2BIE9nJ/bNgeHhvD
"text/plain": [
"<matplotlib.figure.Figure at 0x7efcca4f9150>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAD8CAYAAADzEfagAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvWmvJVl2HbZOTHd8Y+bLsaqyqrq6q5vdzW6yObVJwYQk\nCoZtiP5kyIYN2TDATwZkwIZF+RcQMGD4o0FYBgTYgG3ANiQbEgyZEAVLtghO3U12dXcNXVlTDi/f\n/N6dYjr+sNfeJyLuy6x8XSwyZcT+kDfvfTGcOBFxzjp7r722896jt956662357PoL7oBvfXWW2//\nMlk/aPbWW2+9XcH6QbO33nrr7QrWD5q99dZbb1ewftDsrbfeeruC9YNmb7311tsVrB80e+utt96u\nYJ/boOmc+9eccz9yzr3rnPvNz+s8vfXWW29/nuY+D3K7cy4G8DaAXwPwMYDfB/DveO/f+jM/WW+9\n9dbbn6Mln9NxfwHAu977HwOAc+5/BPDrAC4dNLemI3/r2uanHNL95K25ZNfLjlZWNQCgrmUiydJn\ndM9TmuOe8a0oSgCAhxzfOYckSS7do+ZkFjl3ybnkh4rtnOcVAOD4bI6qLAAAw9EYAJDwGiLneM7m\nWWR/nTejyPG7R1XJj2Xd3rYocgDAeDKR40URikLOmcYxAGAy3ZD2OVnIRHGKJOaixksfL5dLAMBi\nds5ze2vjZDyVc4xGchyeU9sAONQ8zkKPs1gAAFarnO2KEcdy7aXXa5fPmvd5MBphxOsI96F7B9kP\nxQrzuZyjXMmnr+TcMfstisLCrarknhRlyX4r+XuNAFR0v077ammfbNa8b+G50G299+EG6if/FrPP\nnXOoKj1m3dpUTft+NBxYH5a8hvYT8Cnm1v6ztr/1k15DVcPzmj+zXaWx3ffK48B7v/dpu31eg+Zd\nAB81vn8M4BebGzjnfgPAbwDAzd0N/De/+TdgV+GcXbRr/ga0L9Q6qP3QBWs/RJ3zr/12dCIv8HIl\nD8vd29e4bfshb+1vH67d3uZ5ufujxwcAgJwP43AwwPXrOwDCAJjyQc9L2SZLEnupGgcGAJyvZJ8/\n+OAUAPC//V9/gKMn+wCAN7/2dQDA9ds3AACjQczjx4g4mOkLpC/4aDAAAJRlibPzFQDg4DzhtrLN\n48cfAgB+5hd+AQCQjKZ4uP8IAHBjawsA8O1f+VVpXzSU69y9g50N+b/L5wCAd97+AQDge7//zwAA\n42yJyTAFAPziN34ZAPDTX/sGAODs8QcAgJht8C7CnAPX99/+EQDgO9/7UwDAj9+Xx244nmC8cxMA\ncFzItSdOrk8H6te/9k18/eflOnavXZf22cDHgaeWCeHo4X388Xe/AwB48q6cy5/IuTcmvM7JEI6j\n0fm53JNH+08AAJ88PAIAnJ7MkHMAc7wPw4H08TDLpH0cnJe5h+NgHvMz54AWZdJXdVHYROk5QOv9\n3d6RyStJUpydznhMDvh83nSwGg+lb776pTfw3nvS38cnJ+wJDubWMx6I+Zv3reNFacxrc0AUt/b3\nvO5sLBO6S+Ua8osLVLzmurr6yldfsyhyOhRYY+unraQd4KL2O+xz/8HznO/zGjQ/1bz3vw3gtwHg\nzXs3vVyF/RVhNOpedGMA8e3/hKHNdba9DBZ2pyQHKNJ66vHWT33Z3+QPzcGz85trvpiXD4iXDfQ2\nM3tFmhw8Z/LSFGUFB3lQs4woLZeX7PXX3gAA7G5tYjiQh1ZBpKKh3W0ZwCPUeHwiL9k/+l0ZII5P\n5OWv+TQOBnL8eVlaZwwGMnhEibwMZcWBJ4ltMJmdy4BVFByoR7JP7JdI2P+3b8hkNR7JcbZe4sCf\nyN/3j0/wzvfeAQBcXEg7v/7lrwAA/tVf+jYAYLKxhYcyPuOgluu99fKrAIAnDz8GAPzBP/9dPH54\nGwCwc00nSJ38xHytKP4cm9uyTXVD2uMzuZZq
IYNLuVohbo+5KCv5YZXrBOURs380pKCT4mop9yrn\nygFxhjhj/xDFR14HKyLGukbEvw2JzNNYjq/ofn4xR8XBMSL6rsBzcMKM9T3zFUYjGbyH2a4ch4Nf\nyUkrTTIcncq1n8+l/8OrSPScDmwAqxT9s29Ltt1xsK99hbqLS66AFONErnMwSAx8FOxDV+vh9Poa\nx/8JPZOfVyDoEwAvN76/xN9666233v6lts8Laf4+gC86516DDJZ/A8C/+6l76dTkLl8+P2PHzkfb\nD+Tg7bfuLk0wGvxET9somFvbxLW/NtGy7RR1d7rkuE+f/sLKQ7ZZcRY9uZClTVV5xANBGelQ0MLm\nhvjs3nxd5rCbN29hNNkGAJREHzMujW5cE2RRVwWKHwsam89lmXlwKMv+yUSWcQMu5U/OLmw1MCBq\ndLos8/IZx5G5Apa5oIvJVPyWujStc2+IZjoZcT/5WxJPeN1yjNKnuPfKFwEA3/jazwEARlzyDYYj\nbuOA+48BAFEufbKxK0g6pZ/3+tsvo8yl7YuF+CeHI7muWNFVLm6K5XyBCZHcKeHLtZdfkT46kXOv\nzo+xWAryenhwBgD4ZF9QaF7IeSrvDC1WlSD8WdG+57qSSKIIKZflpfn86NPUzyiypfvOdXExTNRP\ny/765OMPsSjkOvQZtAUc3RHql41RouC2ik7rWFGy9j9QqI+UgDUiqo0zDimRM/+1IxKMiYAdt9Xl\nOmofPFndx189bFFYifqOX1eX9ItFaQfQbfxTxxFnLoWrxks+l0HTe1865/5jAP8ngBjAf+e9//6n\n7ddcTIe+a1+YBVEuc25aB7X3qb2tvBu7rN+d4Gw3TN9uGC7rXn/ptt47O4frDuLN45nvdu0A9rV7\nztoCQPKXk3N5UYsiV1cT1Kkz4SCgQS3nHBL6zlByyRbJoKk+nggJlisGMFYrPSkAYDjsLMHryh7e\nEYNPXju71hcyRuzk7RpzcJvTr4hazpPEEQb00w3YPvWjLvlyJbEODrexc+1WaxsdgMpK2rsqHQq+\nTEmmy2Fp15CBql/6q7+GJa9P26VBmIJujfnpsRzDAQv6UVdLBoQy3t+Ek8NojOMjua73PxT/9eL8\nQvqLS+h0OAxxGy797eXlgKaTT5xmMNdTVbSutxlvmXKQzNKMfcHAHP2LWZrBeU6sDKrV9JnrwBbR\n/5wXNc7OxK8xX9B/2hl4vP0DW6taEJHL/abF7H/zjbIN6oN1df30pXLz1FE7mBjccuGd0cHSXNPq\nu73M26XvJQfzus7XN7rEPjefpvf+HwL4h5/X8Xvrrbfe/iLsLywQtG6+sTqPGr92/mdTRgOPrq+V\n1z46q+i17655Co3Sdo53adDHudY2rT+tTW+d6bQ1gyuSflbwSkyRZs3l78VMkGZVFphsyLLX2VJX\nEFTGSGWcpKF3eZ0lKUOK+oAI5+eCNlZ0qCs6Gk/k+BGXhHVdI+W5xmMuozUYwG2ixtJKP88uGETh\nOdM0xojBpQHRbFXKzD+NySQgckoHA7u+kqilWM3ZHgY2IqA26g1RdmdFMhyNUXXWgxndDnUpvXRO\nNDSbz7Bi4EJRo9GnZoLiDg6e4MFDcQloAELPHWdy3CSJUPOYPtVlqrRhyeAYnBx/OZ/ZclfpP4oQ\nDdFVFebsS0XA6mKoxkSsSRKeGUWYGhhSlkAk51nlhfWt0tUs0s6+iuPIVhdxokGoAbtWvi9WS1Ts\nBF1N1fac6Q8hmLVmnUWjvyRy02FYcVXWdpN1V4Ku+TuvE3o/i/VmXGZ9GmVvvfXW2xXsBUKabQuo\n4PLZBfCXMIradJG2dfwfzwCNOrMaUmwCTtuPDvWnOWNa0FWP321wvY5mLwWa7QBVxXYtiFjOiHTq\n2mNrR4I5Kf1IijSV6Ox9jXwlyHTFoExOX52S1Mva48mh8ArnMwloVPQVqv9MgwS+qpElijTHbJ+0\nM6KPLYrc
2uy8JA1lMCSRfTDFxqYEMhCTL6qBCCKenN/z5cLQWE3EtWBAZ1aQZlMAK3KqEqJs9dka\nuyaODNWWRJEJXwklu28
"text/plain": [
"<matplotlib.figure.Figure at 0x7efd3c4af990>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAD8CAYAAADzEfagAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvfmvJNl1JvbdiMjI/e1b7b1vXCRyuIiiZHEkUpqhNMOZ\ngUeyZ2xojAH4kw0ZsGFx5i8gYMCwfzJAeGwIsIyBMNo4EkWJZG9kq9nd7K5mr7XvVe+9elu+3DNj\nuf7hfOdGZlY1u15t/dgdh2BnZb7IWG5E3vudc77zHWOtRW655ZZbbrdm3gd9ArnllltuP0+WT5q5\n5ZZbbnuwfNLMLbfcctuD5ZNmbrnlltseLJ80c8stt9z2YPmkmVtuueW2B8snzdxyyy23Pdg9mzSN\nMf/IGHPSGHPGGPONe3Wc3HLLLbf7aeZekNuNMT6AUwC+AuAKgFcA/NfW2nfu+sFyyy233O6jBfdo\nv58DcMZaew4AjDH/EcDXANx00iyXi7Y+VUMyHAIASqUSjK+nJpO6Tu69/kBee30Eng8ACIsFyHEE\nOPu+fF4uFQEAM7PzCILCXby8nx/TcUuSGABgYGA8Gac0TdxnAOBx3Iwx9/s076klvM5GYxcA0On0\n+JfM0dJxKoTy2ezMNAAgDLKfiMdn8sM2Ph8ts+/xb+DVV49vWmsX328P92rSPATg8sj7KwA+P7qB\nMebrAL4OALV6Bb/3e7+F3WtXAACPPPoEyrNzsqFNAQBRLD/64++eAQCceOMk5qZqcrCjRwAApWIJ\nADA9XQcAfOyxRwEA/+Jf/mvMLh6Q40If+I/Ggx9FshA1d3YAAEEQICzKYtLttN1nAFCpTcn7QuFD\nNTHstmSy/PZf/zUA4Mc/lrXb80qwVq4zjmWcVg5WAAD/8l/8DgDg4OyC7MQAtSmZSAOOn/mIPEM/\nnzbpQac3eU3HtjWmevFW9nyvJs33NWvttwB8CwAWFmdsp99FyJU8KJYAXx9MQQlxHAEAWu2O7CAF\nfF/QY8oHPyFaGHKCPXtRxuD4T17Er/76VwEAYbF8T69rv1mhEAIAZhfkxz/o9RFHgtaLXGTUeu0W\nACAICyiWZPLwgw/sEblja3fkep5+9jkAwCs/OQkASK2MCVKLlM+MF8hn62vyfP3nv/47AMDvfPUr\nAIAHDhxCZ1cm32JVxqZUqQLIPJzc9pPpgmYn3o++3t6id6/u9lUAR0beH+ZnueWWW24/13avYMQr\nAB41xjwImSz/KwD/6r02TpMEnd0GFiviHnphEYgZd/JkNegPBWl2W+pSeggKgjQ1XmdTWQM0PtVh\n/POFF1/AyoHDAIAnP/kZ2a3384ugbsc8xn/L1SrSRBBmvydjPBz0AWSx4CRJ0aZLGxQE8ZfK5bFt\n9qsprmh3Wnj62WcBAM89fxwAMBjw3BnTtWkCS7RhUrponlzvpUty/X/+F38DAPjt3/4ynnj4MQBA\nxHGLGfooM6whqJz7u9sXlttdspshzb0lw+/JzGGtjY0x/z2AvwXgA/i/rbVvv9f2cZxgZ7OJmWMS\nxxxaoDKBplvdLgAgjWTynJubRbkm7pHVHwGvPU0ZB2UCYH17Bz/60bMAgPnFFQDAyqFj4wf4CJkm\nfMo1iQmHTJj1dYyTBD4TZzHHu8WJtVgW17RYkonX8/aXa9pqNwEAzzz3HP7+71+XDxm+CQJ5jeLM\nZdN/6ULr8QNDd311Vcbkz/78u/jtr8ok+alPfBJA9iy2GxIvLlaqKHF88CGKCX847fbvzz2DW9ba\n7wD4zr3af2655ZbbB2H7wkeN4hTXNlroRJK42Wx3cXBFEhczs5Kx3N6W1VxxzdTMNMISkzpGXSuu\nHt44TE3SFBeuSEj1xReeBQB8+R99DQBQm5od2fKjZXrNSseq1cXNTOIE/b64oAqYPG4zHEjIQ13T\nUrkMn8km7wNEV00izKefeQYA8NJLr6PTkWdm
2JeQjucLqyK18mo837nnHpM5mlQ0TKx6nlz35uYQ\nf/nt7wMAIoaKPvcPJNRTYQij22oiJm2uUpdj+P6++Il9hO1nPZO3x6TZX75Vbrnllts+t32xDCZJ\nimazg05XUMz2dgOXLq8CAGZnZ7iNrO4hV3VT8JGSZ+URHWRxJHl1MU5rMSBl6c13JLQ6v7AEAPil\nL/667HeCfvORNI6fXwhQCSTeGTFu1+9IbE8HVYq+gEGvBzDhpnFOTdDdD66nxjC/T4T58o8l6dNo\nbODMqZ8CAHxig8WVh+X8NKlVCEboQuPPkMY6E41xegXs7krC8a++IxQm5Q7/8ueEglyfncOgK5Sl\ndqMBAChVJe6uz9eHif/6UbUcaeaWW2657cH2BdI0AHzjuWxkJ47Q60tsaLchBOVSWRBmjRlf0+66\nShaGNJFwmyCUVd33WYKZpEgIGdo9yQK/9MqPAQBLi8sAgEc/9ovwvf1Np7mfpogoDCVeWSiMxzT1\n1RjPbdvtCMpy2XlmkoPC3X/Mmm0lrj8LAHjlZUGVO411AMCZE2+iy21WloUxUa4IwoyZ8U/S1MVh\nDU/R2oSvLMl1sc4UQgQBmruyzXe/+0MA2Vj8yi9/EbWaxDIDftZj1ZXGOkvV2r6nbX24bZRmlMc0\nc8stt9zuue0LpAkA1qYZIjEerNacD2XFDlhO2WG0qdVsO35hGCqnsMq9ycpRIEryrHWIyfdlnVjb\n2gIAPP/8DwAAc4vLWFo+xK/ncadJUzSpccsC66/jwQB9xjRdPJA82XazyW3Du0qO32018YNnngUA\nvPIKEea2IMzTJ94AIEUQJd7zGcbFNa4YD3l/bepgQ6plyBoPTxgv5+kaz8uC5IzndjryTH7ve+K1\ndHt9fPkffgkAMOUI7/Js9ok4240dVOgt6Rh+NLkbP7+WI83ccssttz3YvkCa1qZI44Gr3PCCEAWu\n0Fr+Nz8rKHLAeX57q4UBEU4QyDb9rnALO6xsUXRTDAqOwxklZX5HUMjJs6Ka9PKLz+PXvyLKNqr2\nk9t7m8YCw1IJARFdVpYp98ULtCwzQYflr0V6B+oFqEzdrWCtJks7f/D0M3jxx28BAHo85uXzch8V\n0RnfoFqTe330sHgQO+1xlOv7xjkVhijSMoZp6dGkijiN52LelqwNJW0MhrLtc88dx4Cxy9/6srAy\nZmekyq0yJc/UsN9Hj6IzqkBVqgjy3G/VVR8uu1nJ5O3FNPfFpClmYVn/izhCxiKSBymi/xRzm0JQ\ncO5SSupHryPbDAeSUCoU5OEsl4sYMMk0NS0Pb5GuUYOD+OyLL6BMcvcvf/FLAIBSSVVscvfpZ5n+\n2CuOXiNjO+xL0i1JEhiGRaKBTBRDTi7qohbD8D3HebfJyZK0old/8iaSiBQoUtEWlx4EAPhkpU9V\nQhw7Jpoxjz32CADgtbdErdCXr4z9VFziiyGi1OmQMjGUWlhmHD2nR8qJlgt5HBv8+EWRnRvyOv/x\nb30ZALA0LxS3YrnsKFnOZd8VelKZ41cI1W3PbT9avrTllltuue3B9gXSNNBAu6CEJEmRpIJShkOu\n4nSTHM3I8zIlbRLeFZ4qbSSgG+X7wM7WNgCg2RT0GVaZ0AhlHzudFta2JTm03ZBtf4mk5ZWVo7K/\nQnEEDeXo871M71HAhEcSx+iT6jVMxStQV3egLn2vj2JZ7okm7VQ8+PtPC8J87bU3AQBbm9dw+dJp\nAEC1IkmemdmDAIAHiC6/8qVfwRw1RNdZgquOTMp7GBgPHsM2+jc/kOetyGsYsmRyGCXOy/GMdhNQ\nRSMtxfQdGf7V1yRc0CPi/Cf/WHQ5D64cdBqlFYoaD1my2mMIIy7JMYvlSu6y33W72e93b2Oc35Hc\ncssttz3Y/kCavo9yre7iXGmSAkQkCcsfGR6DtvrxfQ9sDYSECGAw0FJLQTH1OcaIahV0uxI3Ggxl\nR710yP3I
EPgesLMtyOY//fmfAgDinpCjv/irvwYAKFXrqFLzs0KEk8c739/8IECFMn4FxpZ7TNq5\nskzPc+T4jc3rAIAXXn4
"text/plain": [
"<matplotlib.figure.Figure at 0x7efd3c4afc90>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAD8CAYAAADzEfagAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvfmvZdl1HvbtM9753XvfXFWvqrq6eu7mPJMSKYuTKMWU\nY9mSkxgyEIP+QUEcwEEs5y8QECBIICCACMsJISuRYlEURVOUOIgUyW52N3sge+7qmqc3D3cezpQf\n1rf2efc1RVax1VLZORvovq/uPeM+++z9rbW+9S2TZRmKVrSiFa1ot9acv+sLKFrRila0/5RaMWkW\nrWhFK9pttGLSLFrRila022jFpFm0ohWtaLfRikmzaEUrWtFuoxWTZtGKVrSi3UYrJs2iFa1oRbuN\n9qZNmsaYTxpjXjXGnDfG/OabdZ6iFa1oRfvbbObNILcbY1wA5wB8DMB1AN8H8E+yLHvpb/xkRSta\n0Yr2t9i8N+m47wFwPsuyiwBgjPkDAJ8G8CMnTcd1Ms93UTIuAGBpcR7lWgUAkMEAALrdLgDAdwQc\ne2GITrc3811QCgAAaZIAACaxfJokRX2uDgAYTsYAgFFH9m00GrKN72M8mci5Ol0eJwUANOs1AMDc\nfBtZKsf0fTlXwnM5vAY/LMnxYJBBFqQpz9ntyzl3dvZk3zjh3QGuK49CFzG7mBkDx3Ar3Zh/BIEP\nAAhLZbneNP/NcJ8sk3vI5Ee4roNySfbzPHvAmWaMgeFxUl6Hw+tz7C7Z6+4TR9bffD3ObD/1+wMA\nQMS+LZcr3MZBlvGcKfsAsg271t5+GAaYTuVZxXzGc805+S0I7T3oM9E9J9zn4GBf9o1ilEol9ouM\nPf13GOpxXHu8v66laSzXm6b5OY1zZD/9zOwYso3baJ/jdfv8TbS/DhzJ93oPBsaOvSSOAACuL33h\nOK7dK4mnsjfHl23sL9dxYBwdg9nstjq0Dx0vH6/ZzL9/dB/cCtDLjnz+uMPIH08/89xOlmWLP+nI\nb9akeRzAtUP/vg7gvYc3MMZ8BsBnAMD1HCytzePekkxO//2/+HU8/P53AQAiduzXv/YNAMBKKJPV\n4r334Ev87kRYlc/77wIADA4OAABX+Ol1hvi5T3wYAPDUhdcAAM995ZsAgF/4xMcAAP7yMl68dBEA\n8LWvyHHH+30AwC//3AcBAL/4X/8aRiP57tjxUwCAg24HAFAN5LqOn3kAAOB4HqJYBuK1y68AAL7x\n2HcAAP/2s78HANjf6SLg/c3NNeV+E9knjmJ2lItSmROxDkK+kKeOHwcA3H3/gwCA4diBybiA8KWf\njody3OkIAFCvlfHw/csAgKUFmWw524K7wnM9eJ5MrOOJXEel3pJPuU04yAd3ypdBFxmdLLNE/ojT\nBH0uUt999EkAwMaB/Pvht7yd91lHHMvBh0O51jiRSa7My9RJ/uyZNVy+dEH6kBPgJ3/pFwEAZ07d\nDQAI/QAhFzCd8C9euQwA+MKffB4AsLexhfvul+fVaMqi+sB99wIA7rpLPsOSjC3XDexEmDe5v3Fv\nFwAQTfoIOYYdnWg8nXC4KKYx0mmf/cSJypG+Ni4nam5rHNdOQsa+4T/NRJoig07U3N9OZHINk4Hc\ng/F8ZJFMiHvbNwAArZWzAIBSZY5HMOjuXAIAxPGIPSF94wRy/9VKBT4XMJ2Qp+MB/y3nDstN3mfe\nB1kq59YJ2jgBz3m476dH7k9/y5DPhAmPF/G2uTDZuTSV//RvACZYvYJbaG/WpPkTW5ZlnwXwWQAo\nVcKsVi1jMJLO2N3ZR78nL1XKNyZjRx87fhIAsHrqNCo1GdDrG/LA4+vykDoDeTgvn5dJ8Gx7AY+/\n/CoA4Gvf/DYAYI0vw8iTDn/0O99F2JEJ8MHTMhmNKoIIl06sAACee/4HaDflQa+t3cUbIZpVJObp\nCxBgwONtbm4AAB5/8lkAQK8jAy1NAIdo0S/L
tVe8Ko8rH5NplKNEHjtkn3j83N4lMk59eLpiE9kp\nYmrUZcDXygHmqoKuy4Fsm/D4MSfsJEkQc8IfjmXQZRy8WSz7hL7crx8EcFy+MEcnFY4uL3XhtuX8\n73vfOwEAr56/CgAo8fri1IfjKHrnYJ5KJ/g+u4R9PY0Tez/nXpHn+v/8/r8HAKyd5vhYWcbayVPc\nUV7ARx99XPprYwsAcPfZsyhxDOk5u7syUaxnMv7KVZlMy7UFhFV59kFJ+s915cKCsmzjhxUYIs00\nkgk/GcsEmXLBM14IJ2ywg/iQE+njNJHxH+ukCoOg0pb9TI7Kbr8Zu5hm4GKsiI6TqOfLWEriMeKp\nWEYRJ8/pRMZrudLkrimCQLZ3OQ5GfRmD4CTleCGMQysikYXb5dgOXFoUyYjXF8LY2UwReon/1oV4\nCgO1xjg+jC7ch8fdLMJUVK/v58yakx1F9rfW3qxA0A0Aa4f+fYLfFa1oRSvaf9LtzUKa3wdwjzHm\nLshk+WsA/qu/dusMSKcpYi4S0zTDkKgz4yqwN5CV/+rNmwCAxpkTGBAZbXXEDHd7soK/dkFMh1fO\nXZbjnZjgsQtizl15Tb5z1o4BAPr7su+yZ/Dghz4AAPjei88BAD7+DkFF7/zYJwAAl69expWL5wAA\n5869DACIaQYMa7I6N1cEJQRhCbtb4qHY2xcT8tIFQf9JJPsEvofJVO5hf09Q6fyCIAv1U5bqdZTU\nzCTS3KPb4YAmr+8L6qpWqqhVZTVXX6HneTOf1VoJjab0U7M5+/gj+rCmkynGY7mf8Viuz/qJx/IZ\nEY14E8/68fQcAaGhx9XddT3rWlhYmpc+4Hp9MJTjjSYeslS2T+jqSIkS6G4EjQ2MxlPU6GcO2Dfr\nNzYBAFcvy/hwHCChWaho26XZG/Ne1nc2sLgsrorjK2JN1EqrAIB6WU6a0Xc37u9bl4VXEmSp5qoX\n6rVUEZblO78k58ys60KOk8YTxJEgLPXpOR6fbyDHNYo443Hu53xDzeQ+Vvj2q8PNC9VSCi0KLYVE\nn1MZZ0kiz84YFyDKzojQ9fmqr/8wMs4gf3sB3UxmFimmaQxkPL8z60tOaa4nUd8iVevW0OMoijdO\nbo7ruVO9Pu2KN96fb8qkmWVZbIz57wD8BQAXwL/LsuzFv277IAxw8u5TuP6S+Bu/+4PnsRHLS3ny\nntMAgAvrYuLu3BST+aBWxjf/6lEAwO4NMc/rr4o5rgB9nyb+4OUB6jQJGjQnepysnn5BJr/333sW\n3nF5YfrfFhP+3n8ok+Yi/WRBvY5mgy9IRT43L58HADiumCBXXn5GrsF1rVmeTeVBrs3Tb3kg2x4M\nRpiM5AUajeThejR7aq0FAECCDPt9mYiHPZl8q5yUGjR56w3xN9aqdRsc8jnxBBzEvi+DsFz27SCj\nyxEezesg0G0DlCsV3pcMYo9+Y8Ry7YmalFmGKOJkO+UEw4Gpk4zruq8z4QNOKh739WIXqdEJPuN+\nDLBkOnnK78PxFHW6GOZa0qfXb2wDADo96cckjTEdyfN3nIR9K/sP2J+DXg8u+7Jal/v7/oklAMB9\nZ0/L5z1nAACnjh/DQksWtDIn3WQkJmkUMWAVhKg3JY5Qoinvc4J1af56pcbrJ9Kp+gXZtxoIM7BB\ntjf+qr8+IDX7Kd87ToigKmOv7cp46OxcBwCMBvLuVWsLcD11pcg96IIbcmFy3JIdIwknPpdj0foS\nOaEl0wkQyvj03RKvSvo4d1mMc9+vM+uCSjgm00ziIwBs8FTHztH7nO2K2+vdN82nmWXZnwH4szfr\n+EUrWtGK9nfR/s4CQYdbGAa4657T2KPp/NCJYxgRDTz76BMAgNdeEUS3tS2r+18+8wPsbwu6SAiZ\nRqPZSJ6hE/l0q4lPfkCitI1VQZNffuIpOd4VMedOfupTeOE1Mb3XWmKGrJyRCGpMp/5k0EGDv2kw\nZuEdQgoI
q4I8wZW21+kAE1lpKycEEf7qP5aV9g++8EUAwPe//wKQ0qypCNLpjmSfwdaOnHM8QXdf\nVviMEfUz7xZmwfETEuh
"text/plain": [
"<matplotlib.figure.Figure at 0x7efd3c4af750>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "NameError",
"evalue": "global name 'get_features' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-11-b167e36c668e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-10-6d06f5d0c974>\u001b[0m in \u001b[0;36mget_batch\u001b[0;34m(idx, X, Y, W, dataset)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mW\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_features\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mX_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: global name 'get_features' is not defined"
]
}
],
"source": [
"_,_,_ = get_batch(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}