Sketchback/Project_1.ipynb
2017-04-09 01:43:24 +02:00

219 lines
7.2 KiB
Plaintext
Executable File

{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd \n",
"import cv2 as cv\n",
"import os\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"from keras.models import Sequential, Model\n",
"from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D\n",
"from keras.layers.core import Activation, Dropout, Flatten\n",
"from keras.optimizers import SGD\n",
"from keras.utils import np_utils\n",
"from keras.callbacks import TensorBoard\n",
"\n",
"np.random.seed(1337) # for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"((1005, 480, 640, 3), (1005, 480, 640, 3))\n"
]
}
],
"source": [
"num_images = (1005, 480, 640, 3)\n",
"X_train = np.zeros(num_images)\n",
"Y_train = np.zeros(num_images)\n",
"\n",
"for i, file in enumerate(os.listdir('ZuBuD_Sketch')):\n",
" file_path = os.path.join('ZuBuD_Sketch', file)\n",
" img = cv.imread(file_path)\n",
" img = cv.cvtColor(img, cv.COLOR_BGR2RGB)\n",
" X_train[i] = img\n",
" \n",
"for i, file in enumerate(os.listdir('ZuBuD')):\n",
" file_path = os.path.join('ZuBuD', file)\n",
" img = cv.imread(file_path)\n",
" img = cv.cvtColor(img, cv.COLOR_BGR2RGB) \n",
" if img.shape == (240,320,3):\n",
" img = np.resize(img, (480, 640, 3))\n",
" Y_train[i] = img\n",
" \n",
"print (X_train.shape, Y_train.shape) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_val = X_train[:100]\n",
"Y_val = Y_train[:100]\n",
"X_train = X_train[100:]\n",
"Y_train = Y_train[100:]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_4 (InputLayer) (None, 480, 640, 3) 0 \n",
"_________________________________________________________________\n",
"conv2d_22 (Conv2D) (None, 480, 640, 128) 3584 \n",
"_________________________________________________________________\n",
"max_pooling2d_10 (MaxPooling (None, 240, 320, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_23 (Conv2D) (None, 240, 320, 64) 73792 \n",
"_________________________________________________________________\n",
"max_pooling2d_11 (MaxPooling (None, 120, 160, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_24 (Conv2D) (None, 120, 160, 64) 36928 \n",
"_________________________________________________________________\n",
"max_pooling2d_12 (MaxPooling (None, 60, 80, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_25 (Conv2D) (None, 60, 80, 64) 36928 \n",
"_________________________________________________________________\n",
"up_sampling2d_10 (UpSampling (None, 120, 160, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_26 (Conv2D) (None, 120, 160, 64) 36928 \n",
"_________________________________________________________________\n",
"up_sampling2d_11 (UpSampling (None, 240, 320, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_27 (Conv2D) (None, 240, 320, 128) 73856 \n",
"_________________________________________________________________\n",
"up_sampling2d_12 (UpSampling (None, 480, 640, 128) 0 \n",
"_________________________________________________________________\n",
"conv2d_28 (Conv2D) (None, 480, 640, 3) 3459 \n",
"=================================================================\n",
"Total params: 265,475.0\n",
"Trainable params: 265,475.0\n",
"Non-trainable params: 0.0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"input_img = Input(shape=(480, 640, 3)) \n",
"\n",
"x = Conv2D(128, (3, 3), activation='relu', padding='same')(input_img)\n",
"x = MaxPooling2D((2, 2), padding='same')(x)\n",
"x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)\n",
"x = MaxPooling2D((2, 2), padding='same')(x)\n",
"x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)\n",
"encoded = MaxPooling2D((2, 2), padding='same')(x)\n",
"\n",
"# at this point the representation is (4, 4, 8) i.e. 128-dimensional\n",
"\n",
"x = Conv2D(64, (3, 3), activation='relu', padding='same')(encoded)\n",
"x = UpSampling2D((2, 2))(x)\n",
"x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)\n",
"x = UpSampling2D((2, 2))(x)\n",
"x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)\n",
"x = UpSampling2D((2, 2))(x)\n",
"decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)\n",
"\n",
"autoencoder = Model(input_img, decoded)\n",
"autoencoder.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 905 samples, validate on 100 samples\n",
"Epoch 1/50\n"
]
}
],
"source": [
"autoencoder.compile(optimizer='adam', loss='mean_squared_error')\n",
"autoencoder.fit(X_train, Y_train,\n",
" epochs=50,\n",
" batch_size=64,\n",
" shuffle=True,\n",
" validation_data=(X_val, Y_val))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"Y_test = autoencoder.predict(X_val)\n",
"\n",
"fig = plt.figure()\n",
"a = fig.add_subplot(1,2,1)\n",
"imgplot = plt.imshow(Y_test[0])\n",
"a.set_title('Prediction')\n",
"a = fig.add_subplot(1,2,2)\n",
"imgplot = plt.imshow(X_val[0])\n",
"a.set_title('Ground Truth')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}