Last active
December 11, 2018 02:06
-
-
Save j-min/481749dcb853b4477c4f441bf7452195 to your computer and use it in GitHub Desktop.
TensorFlow 0.9 implementation of BasicRNNCell based on hunkim's tutorial
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### BasicRNNCell\n", | |
| "#### TensorFlow 0.9 implementation based on hunkim's tutorial\n", | |
| "https://hunkim.github.io/ml/\n", | |
| "\n", | |
| "https://www.youtube.com/watch?v=A8wJYfDUYCk&feature=youtu.be" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "{'o': 3, 'l': 2, 'e': 1, 'h': 0}\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "char_rdic = ['h', 'e', 'l', 'o'] # id -> char\n", | |
| "char_dic = {w : i for i, w in enumerate(char_rdic)} # char -> id\n", | |
| "print (char_dic)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0, 1, 2, 2, 3]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "ground_truth = [char_dic[c] for c in 'hello']\n", | |
| "print (ground_truth)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "x_data = np.array([[1,0,0,0], # h\n", | |
| " [0,1,0,0], # e\n", | |
| " [0,0,1,0], # l\n", | |
| " [0,0,1,0]], # l\n", | |
| " dtype = 'f')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tensor(\"one_hot:0\", shape=(4, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x_data = tf.one_hot(ground_truth[:-1], len(char_dic), 1.0, 0.0, -1)\n", | |
| "print(x_data)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Configuration\n", | |
| "rnn_size = len(char_dic) # 4\n", | |
| "batch_size = 1\n", | |
| "output_size = 4" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "<tensorflow.python.ops.rnn_cell.BasicRNNCell object at 0x7effb759c9e8>\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# RNN Model\n", | |
| "rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size,\n", | |
| " input_size = None, # deprecated at tensorflow 0.9\n", | |
| " #activation = tanh,\n", | |
| " )\n", | |
| "print(rnn_cell)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tensor(\"zeros:0\", shape=(1, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "initial_state = rnn_cell.zero_state(batch_size, tf.float32)\n", | |
| "print(initial_state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tensor(\"zeros_1:0\", shape=(1, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "initial_state_1 = tf.zeros([batch_size, rnn_cell.state_size]) # 위 코드와 같은 결과\n", | |
| "print(initial_state_1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[<tf.Tensor 'split:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:1' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:2' shape=(1, 4) dtype=float32>, <tf.Tensor 'split:3' shape=(1, 4) dtype=float32>]\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'\\n[[1,0,0,0]] # h\\n[[0,1,0,0]] # e\\n[[0,0,1,0]] # l\\n[[0,0,1,0]] # l\\n'" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x_split = tf.split(0, len(char_dic), x_data) # 가로축으로 4개로 split\n", | |
| "print(x_split)\n", | |
| "\"\"\"\n", | |
| "[[1,0,0,0]] # h\n", | |
| "[[0,1,0,0]] # e\n", | |
| "[[0,0,1,0]] # l\n", | |
| "[[0,0,1,0]] # l\n", | |
| "\"\"\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "outputs, state = tf.nn.rnn(cell = rnn_cell, inputs = x_split, initial_state = initial_state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[<tf.Tensor 'RNN/BasicRNNCell/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_1/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_2/Tanh:0' shape=(1, 4) dtype=float32>, <tf.Tensor 'RNN/BasicRNNCell_3/Tanh:0' shape=(1, 4) dtype=float32>]\n", | |
| "Tensor(\"RNN/BasicRNNCell_3/Tanh:0\", shape=(1, 4), dtype=float32)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print (outputs)\n", | |
| "print (state)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'\\n[[logit from 1st output],\\n[logit from 2nd output],\\n[logit from 3rd output],\\n[logit from 4th output]]\\n'" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "logits = tf.reshape(tf.concat(1, outputs), # shape = 1 x 16\n", | |
| " [-1, rnn_size]) # shape = 4 x 4\n", | |
| "logits.get_shape()\n", | |
| "\"\"\"\n", | |
| "[[logit from 1st output],\n", | |
| "[logit from 2nd output],\n", | |
| "[logit from 3rd output],\n", | |
| "[logit from 4th output]]\n", | |
| "\"\"\"\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "TensorShape([Dimension(4)])" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "targets = tf.reshape(ground_truth[1:], [-1]) # a shape of [-1] flattens into 1-D\n", | |
| "targets.get_shape()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "weights = tf.ones([len(char_dic) * batch_size])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [targets], [weights])\n", | |
| "cost = tf.reduce_sum(loss) / batch_size\n", | |
| "train_op = tf.train.RMSPropOptimizer(0.01, 0.9).minimize(cost)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 1 2 2] ['e', 'e', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 2] ['e', 'l', 'l', 'l']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n", | |
| "[1 2 2 3] ['e', 'l', 'l', 'o']\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Launch the graph in a session\n", | |
| "with tf.Session() as sess:\n", | |
| " tf.initialize_all_variables().run()\n", | |
| " for i in range(100):\n", | |
| " sess.run(train_op)\n", | |
| " result = sess.run(tf.argmax(logits, 1))\n", | |
| " print(result, [char_rdic[t] for t in result]) " | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "anaconda-cloud": {}, | |
| "kernelspec": { | |
| "display_name": "Python [tensorflow]", | |
| "language": "python", | |
| "name": "Python [tensorflow]" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Author
@goddoe 수정했습니다. 지적해주셔서 감사합니다!
좋은 예제 감사합니다. 위 예제를 버전 1.0에서 실행하기 위해서는 링크된 글을 참고하시길 바랍니다.
위에 모든분들 너무나 감사합니다. 혼자 실습중인데 너무나 오류나서 힘들어하고 있었는데 1.0버젼에 맞춰서 코딩수정까지 자료가 있으니 너무나 힘이 됩니다 ! 감사합니다 !
tensor flow 1.3이상 버전에서는
In [17]을
# Launch the graph in a session
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for i in range(100):
sess.run(train_op, )
result = sess.run(tf.argmax(logits, axis=1))
print(result, [char_rdic[t] for t in result])
로 해야 작동됩니다.
1.8.0. 버전에서
[10]
x_split = tf.split(0, len(char_dic), x_data) 을
x_split = tf.split(x_data, len(char_dic), 0)
으로 해야 돌아갑니다.
1.8.0에 맞게 수정해서 올렸습니다.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
4 번째 셀에
x_data = np.array([[1,0,0,0], # h
[0,1,0,0], # e
[0,0,1,0], # l
[0,0,0,1]], # l 이부분이 잘못된 것같습니다 [0,0,1,0] 로 되어야하지 않을까요
dtype = 'f')