Created
July 16, 2017 22:22
-
-
Save tomtung/c030219cdb731ad67be00cb049b5dc22 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Using TensorFlow backend.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import keras\n", | |
| "import numpy as np\n", | |
| "import random" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "N_DIGITS = 5\n", | |
| "INPUT_LEN = N_DIGITS * 2 + 1\n", | |
| "OUTPUT_LEN = N_DIGITS + 1\n", | |
| "\n", | |
| "CHARS = list(' 1234567890+')\n", | |
| "CHAR_TO_INDEX = {\n", | |
| " c: i\n", | |
| " for i, c in enumerate(CHARS)\n", | |
| "}\n", | |
| "\n", | |
| "TRAIN_DATA_SIZE = 600000\n", | |
| "TEST_DATA_SIZE = 100000" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Data Generation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "A random number: 3\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def generate_random_number():\n", | |
| " return random.randrange(0, 10 ** random.randint(1, N_DIGITS))\n", | |
| "\n", | |
| "print('A random number: {}'.format(generate_random_number()))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(8, 27988)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def generate_addend_pair():\n", | |
| " return generate_random_number(), generate_random_number()\n", | |
| "\n", | |
| "print(generate_addend_pair())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "An example: ('12+345 ', '357 ')\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def generate_str_example(x, y):\n", | |
| " input_str = '{}+{}'.format(x, y)\n", | |
| " output_str = str(x + y)\n", | |
| " \n", | |
| " input_format_str = '{{:{}}}'.format(INPUT_LEN)\n", | |
| " input_str = input_format_str.format(input_str)\n", | |
| " \n", | |
| " output_format_str = '{{:{}}}'.format(OUTPUT_LEN)\n", | |
| " output_str = output_format_str.format(output_str)\n", | |
| " \n", | |
| " return input_str, output_str\n", | |
| "\n", | |
| "print('An example: {}'.format(generate_str_example(12, 345)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[(11, 12), (6, 12)]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def generate_example(x, y):\n", | |
| " input_str, output_str = generate_str_example(x, y)\n", | |
| "\n", | |
| " input_ = np.zeros((INPUT_LEN, len(CHARS)))\n", | |
| " for i, c in enumerate(input_str):\n", | |
| " index = CHAR_TO_INDEX[c]\n", | |
| " input_[i, index] = 1\n", | |
| "\n", | |
| " output = np.zeros((OUTPUT_LEN, len(CHARS)))\n", | |
| " for i, c in enumerate(output_str):\n", | |
| " index = CHAR_TO_INDEX[c]\n", | |
| " output[i, index] = 1\n", | |
| "\n", | |
| " return input_, output\n", | |
| "\n", | |
| "print([array.shape for array in generate_example(12, 345)])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "training_x shape: (600000, 11, 12)\n", | |
| "training_y shape: (600000, 6, 12)\n", | |
| "testing_x shape: (100000, 11, 12)\n", | |
| "testing_y shape: (100000, 6, 12)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def generate_examples(n_train, n_test):\n", | |
| " n_examples = n_train + n_test\n", | |
| " \n", | |
| " addend_pairs = set()\n", | |
| " while len(addend_pairs) < n_examples:\n", | |
| " addend_pairs.add(generate_addend_pair())\n", | |
| " \n", | |
| " inputs, outputs = zip(*[\n", | |
| " generate_example(x, y)\n", | |
| " for x, y in addend_pairs\n", | |
| " ])\n", | |
| " \n", | |
| " return np.array(inputs[:n_train]), np.array(outputs[:n_train]), np.array(inputs[n_train:]), np.array(outputs[n_train:])\n", | |
| "\n", | |
| "training_x, training_y, testing_x, testing_y = generate_examples(TRAIN_DATA_SIZE, TEST_DATA_SIZE)\n", | |
| "print('training_x shape:', training_x.shape)\n", | |
| "print('training_y shape:', training_y.shape)\n", | |
| "print('testing_x shape:', testing_x.shape)\n", | |
| "print('testing_y shape:', testing_y.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "HIDDEN_SIZE = 128\n", | |
| "BATCH_SIZE = 128\n", | |
| "MAX_N_EPOCS = 1000" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = keras.models.Sequential([\n", | |
| " keras.layers.wrappers.Bidirectional(\n", | |
| " keras.layers.recurrent.LSTM(HIDDEN_SIZE),\n", | |
| " input_shape=(INPUT_LEN, len(CHARS))\n", | |
| " ),\n", | |
| " keras.layers.core.RepeatVector(OUTPUT_LEN),\n", | |
| " keras.layers.recurrent.LSTM(HIDDEN_SIZE, return_sequences=True),\n", | |
| " keras.layers.wrappers.TimeDistributed(\n", | |
| " keras.layers.Dense(len(CHARS), activation='softmax')\n", | |
| " ),\n", | |
| "])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "_________________________________________________________________\n", | |
| "Layer (type) Output Shape Param # \n", | |
| "=================================================================\n", | |
| "bidirectional_1 (Bidirection (None, 256) 144384 \n", | |
| "_________________________________________________________________\n", | |
| "repeat_vector_1 (RepeatVecto (None, 6, 256) 0 \n", | |
| "_________________________________________________________________\n", | |
| "lstm_2 (LSTM) (None, 6, 128) 197120 \n", | |
| "_________________________________________________________________\n", | |
| "time_distributed_1 (TimeDist (None, 6, 12) 1548 \n", | |
| "=================================================================\n", | |
| "Total params: 343,052\n", | |
| "Trainable params: 343,052\n", | |
| "Non-trainable params: 0\n", | |
| "_________________________________________________________________\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Print each layer's output shape and parameter count.\n", | |
| "model.summary()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Train on 480000 samples, validate on 120000 samples\n", | |
| "Epoch 1/1000\n", | |
| "157s - loss: 1.1147 - acc: 0.5823 - val_loss: 0.5931 - val_acc: 0.8039\n", | |
| "Epoch 2/1000\n", | |
| "146s - loss: 0.4039 - acc: 0.8568 - val_loss: 0.3069 - val_acc: 0.8849\n", | |
| "Epoch 3/1000\n", | |
| "142s - loss: 0.1948 - acc: 0.9316 - val_loss: 0.1232 - val_acc: 0.9608\n", | |
| "Epoch 4/1000\n", | |
| "142s - loss: 0.0904 - acc: 0.9708 - val_loss: 0.0693 - val_acc: 0.9774\n", | |
| "Epoch 5/1000\n", | |
| "142s - loss: 0.0559 - acc: 0.9816 - val_loss: 0.0418 - val_acc: 0.9864\n", | |
| "Epoch 6/1000\n", | |
| "142s - loss: 0.0373 - acc: 0.9880 - val_loss: 0.0304 - val_acc: 0.9900\n", | |
| "Epoch 7/1000\n", | |
| "142s - loss: 0.0280 - acc: 0.9911 - val_loss: 0.0195 - val_acc: 0.9941\n", | |
| "Epoch 8/1000\n", | |
| "142s - loss: 0.0210 - acc: 0.9935 - val_loss: 0.0241 - val_acc: 0.9919\n", | |
| "Epoch 9/1000\n", | |
| "142s - loss: 0.0181 - acc: 0.9946 - val_loss: 0.0103 - val_acc: 0.9970\n", | |
| "Epoch 10/1000\n", | |
| "141s - loss: 0.0137 - acc: 0.9960 - val_loss: 0.0108 - val_acc: 0.9967\n", | |
| "Epoch 11/1000\n", | |
| "141s - loss: 0.0143 - acc: 0.9959 - val_loss: 0.0148 - val_acc: 0.9958\n", | |
| "Epoch 12/1000\n", | |
| "141s - loss: 0.0114 - acc: 0.9968 - val_loss: 0.0046 - val_acc: 0.9988\n", | |
| "Epoch 13/1000\n", | |
| "141s - loss: 0.0110 - acc: 0.9968 - val_loss: 0.0067 - val_acc: 0.9980\n", | |
| "Epoch 14/1000\n", | |
| "141s - loss: 0.0076 - acc: 0.9978 - val_loss: 0.0048 - val_acc: 0.9986\n", | |
| "Epoch 15/1000\n", | |
| "141s - loss: 0.0093 - acc: 0.9975 - val_loss: 0.0081 - val_acc: 0.9975\n", | |
| "Epoch 16/1000\n", | |
| "141s - loss: 0.0073 - acc: 0.9979 - val_loss: 0.0080 - val_acc: 0.9975\n", | |
| "Epoch 17/1000\n", | |
| "141s - loss: 0.0059 - acc: 0.9983 - val_loss: 0.0043 - val_acc: 0.9987\n", | |
| "Epoch 18/1000\n", | |
| "141s - loss: 0.0068 - acc: 0.9981 - val_loss: 0.0046 - val_acc: 0.9986\n", | |
| "Epoch 19/1000\n", | |
| "141s - loss: 0.0058 - acc: 0.9984 - val_loss: 0.0044 - val_acc: 0.9987\n", | |
| "Epoch 20/1000\n", | |
| "141s - loss: 0.0064 - acc: 0.9982 - val_loss: 0.0094 - val_acc: 0.9972\n", | |
| "Epoch 21/1000\n", | |
| "141s - loss: 0.0053 - acc: 0.9985 - val_loss: 0.0039 - val_acc: 0.9989\n", | |
| "Epoch 22/1000\n", | |
| "141s - loss: 0.0041 - acc: 0.9989 - val_loss: 0.0042 - val_acc: 0.9987\n", | |
| "Epoch 23/1000\n", | |
| "141s - loss: 0.0050 - acc: 0.9986 - val_loss: 0.0172 - val_acc: 0.9949\n", | |
| "Epoch 24/1000\n", | |
| "141s - loss: 0.0038 - acc: 0.9989 - val_loss: 0.0033 - val_acc: 0.9990\n", | |
| "Epoch 25/1000\n", | |
| "141s - loss: 0.0051 - acc: 0.9987 - val_loss: 0.0020 - val_acc: 0.9995\n", | |
| "Epoch 26/1000\n", | |
| "142s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0023 - val_acc: 0.9994\n", | |
| "Epoch 27/1000\n", | |
| "141s - loss: 0.0044 - acc: 0.9988 - val_loss: 0.0018 - val_acc: 0.9995\n", | |
| "Epoch 28/1000\n", | |
| "141s - loss: 0.0032 - acc: 0.9991 - val_loss: 0.0029 - val_acc: 0.9992\n", | |
| "Epoch 29/1000\n", | |
| "144s - loss: 0.0042 - acc: 0.9988 - val_loss: 0.0085 - val_acc: 0.9974\n", | |
| "Epoch 30/1000\n", | |
| "144s - loss: 0.0033 - acc: 0.9991 - val_loss: 0.0019 - val_acc: 0.9995\n", | |
| "Epoch 31/1000\n", | |
| "157s - loss: 0.0039 - acc: 0.9990 - val_loss: 0.0014 - val_acc: 0.9997\n", | |
| "Epoch 32/1000\n", | |
| "154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0033 - val_acc: 0.9991\n", | |
| "Epoch 33/1000\n", | |
| "150s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0013 - val_acc: 0.9997\n", | |
| "Epoch 34/1000\n", | |
| "154s - loss: 0.0028 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9993\n", | |
| "Epoch 35/1000\n", | |
| "152s - loss: 0.0032 - acc: 0.9992 - val_loss: 0.0038 - val_acc: 0.9988\n", | |
| "Epoch 36/1000\n", | |
| "157s - loss: 0.0026 - acc: 0.9993 - val_loss: 0.0037 - val_acc: 0.9989\n", | |
| "Epoch 37/1000\n", | |
| "152s - loss: 0.0024 - acc: 0.9993 - val_loss: 0.0016 - val_acc: 0.9996\n", | |
| "Epoch 38/1000\n", | |
| "153s - loss: 0.0031 - acc: 0.9992 - val_loss: 0.0024 - val_acc: 0.9992\n", | |
| "Epoch 39/1000\n", | |
| "155s - loss: 0.0025 - acc: 0.9993 - val_loss: 0.0031 - val_acc: 0.9990\n", | |
| "Epoch 00038: early stopping\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<keras.callbacks.History at 0x1a2ad1d6e48>" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model.fit(\n", | |
| " training_x, training_y,\n", | |
| " batch_size=BATCH_SIZE,\n", | |
| " epochs=MAX_N_EPOCS,\n", | |
| " verbose=2,\n", | |
| " validation_split=.2,\n", | |
| " callbacks=[\n", | |
| " keras.callbacks.EarlyStopping(patience=5, verbose=2),\n", | |
| " ],\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "100000/100000 [==============================] - 32s \n", | |
| "\n", | |
| "Test loss: 0.0030354137425270163\n", | |
| "Test acc: 0.9990433450508117\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "metrics_vals = model.evaluate(testing_x, testing_y)\n", | |
| "\n", | |
| "print('')\n", | |
| "for metric_name, metric_val in zip(model.metrics_names, metrics_vals):\n", | |
| " print('Test {}: {}'.format(metric_name, metric_val))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Model In Action" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def neural_addition(x, y):\n", | |
| " input_, _ = generate_example(x, y)\n", | |
| " output_ = model.predict_on_batch(np.array([input_]))[0]\n", | |
| " indices = np.argmax(output_, axis=1)\n", | |
| " return ''.join(CHARS[index] for index in indices)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "163 + 0 = \"163 \" (correct)\n", | |
| "96 + 453 = \"549 \" (correct)\n", | |
| "69 + 557 = \"626 \" (correct)\n", | |
| "7721 + 98 = \"7819 \" (correct)\n", | |
| "5112 + 79646 = \"84758 \" (correct)\n", | |
| "493 + 43044 = \"43537 \" (correct)\n", | |
| "51 + 489 = \"540 \" (correct)\n", | |
| "84628 + 3457 = \"88085 \" (correct)\n", | |
| "1 + 2236 = \"2237 \" (correct)\n", | |
| "0 + 4622 = \"4622 \" (correct)\n", | |
| "67 + 0 = \"67 \" (correct)\n", | |
| "90642 + 68 = \"90710 \" (correct)\n", | |
| "6 + 6 = \"12 \" (correct)\n", | |
| "38973 + 23 = \"38996 \" (correct)\n", | |
| "4 + 5945 = \"5949 \" (correct)\n", | |
| "155 + 321 = \"476 \" (correct)\n", | |
| "4987 + 2805 = \"7792 \" (correct)\n", | |
| "70001 + 8 = \"70009 \" (correct)\n", | |
| "1085 + 36 = \"1121 \" (correct)\n", | |
| "13 + 2969 = \"2982 \" (correct)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for _ in range(20):\n", | |
| " x, y = generate_addend_pair()\n", | |
| " expected = x + y\n", | |
| " result = neural_addition(x, y)\n", | |
| " if result.strip() == str(expected):\n", | |
| " print('{} + {} = \"{}\" (correct)'.format(x, y, result))\n", | |
| " else:\n", | |
| " print('{} + {} = \"{}\" (incorrect, should be {})'.format(x, y, result, expected))" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.1" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment