Created
January 19, 2017 19:16
-
-
Save mehmetakifakkus/b841a94d609891eb34a21111609f4198 to your computer and use it in GitHub Desktop.
siamese
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ Siamese implementation using Tensorflow with MNIST example. | |
| This siamese network embeds a 28x28 image (a point in 784D) | |
| into a point in 2D. | |
| By Youngwook Paul Kwon (young at berkeley.edu) | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| #import system things | |
| from tensorflow.examples.tutorials.mnist import input_data # for data | |
| import tensorflow as tf | |
| import numpy as np | |
| import os | |
| #import helpers | |
| import inference | |
| import visualize | |
# Prepare data and the TF session.
# The sketch dataset is read from pre-split .npy files (relative paths, so they
# must be in the working directory): features and labels for test and training.
sketch_test = np.load("test_set.npy")
sketch_test_label = np.load("test_set_label.npy")
sketch_training = np.load("train_set.npy")
sketch_training_label = np.load("train_set_label.npy")
sess = tf.InteractiveSession()

# Set up the siamese network and a plain gradient-descent op on its loss.
LEARNING_RATE = 0.01
siamese = inference.siamese()
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(siamese.loss)
saver = tf.train.Saver()
# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is its documented drop-in replacement.
tf.global_variables_initializer().run()
# If a checkpoint from an earlier run exists, ask whether to load it instead
# of training a new model. `new` stays True unless the user answers "yes".
new = True
model_ckpt = 'model.ckpt'

# raw_input only exists on Python 2; Python 3 renamed it to input.
try:
    _read_line = raw_input  # noqa: F821
except NameError:
    _read_line = input

# tf.train.Saver's V2 checkpoint format writes '<prefix>.index' (plus data
# shards) instead of a single file named '<prefix>', so check for both forms.
if os.path.isfile(model_ckpt) or os.path.isfile(model_ckpt + '.index'):
    input_var = None
    while input_var not in ['yes', 'no']:
        # Normalize so "Yes" / " no " are accepted too.
        input_var = _read_line(
            "We found model.ckpt file. Do you want to load it [yes/no]?"
        ).strip().lower()
    if input_var == 'yes':
        new = False
def getBatch(length, num, i):
    """Return slice bounds for step `i` when cycling through `num` chunks.

    The `length` items are split into `num` contiguous chunks. Step `i`
    uses chunk ``i % num`` and belongs to pass ``i // num`` over the data.

    Args:
        length: total number of items (e.g. ``len(sketch_training)``).
        num: number of chunks per pass.
        i: global step counter (0-based).

    Returns:
        dict with keys:
          'start': first index of the chunk (inclusive),
          'end':   one past the chunk's last index (exclusive, slice-ready),
          'no':    chunk number, ``i % num``,
          'epoch': pass number, ``i // num``.
        The last chunk also absorbs the ``length % num`` leftover items, so
        every item is visited once per pass.
    """
    # Integer division: float bounds (true division under
    # `from __future__ import division`) cannot be used to slice arrays.
    chunk = length // num
    no = i % num
    start = chunk * no
    # Exclusive end: slicing [start:end] yields exactly the chunk's items.
    end = length if no == num - 1 else start + chunk
    return {'start': start, 'end': end, 'no': no, 'epoch': i // num}
# Train from scratch, or restore the previously saved model.
NUM_BATCHES = 5    # each step trains on 1/NUM_BATCHES of the training set
NUM_STEPS = 1100
SAVE_EVERY = 1000
length = len(sketch_training)

if new:
    for step in range(NUM_STEPS):
        batch = getBatch(length, NUM_BATCHES, step)
        batch_x1 = sketch_training[batch['start']:batch['end']]
        batch_y1 = sketch_training_label[batch['start']:batch['end']]
        # Pair each first-branch sample with a randomly drawn training sample.
        # Feeding the same slice to both branches would make x1 == x2 and
        # every pair label 1.0, leaving the siamese loss nothing to learn.
        partner = np.random.randint(0, length, size=len(batch_x1))
        batch_x2 = sketch_training[partner]
        batch_y2 = sketch_training_label[partner]
        # NOTE(review): assumes one integer class label per row (not one-hot)
        # so that == gives a per-pair similarity flag — confirm.
        batch_y = (batch_y1 == batch_y2).astype('float')

        _, loss_v = sess.run([train_step, siamese.loss], feed_dict={
            siamese.x1: batch_x1,
            siamese.x2: batch_x2,
            siamese.y_: batch_y})

        if np.isnan(loss_v):
            print('Model diverged with loss = NaN')
            raise SystemExit(1)  # non-zero status so callers see the failure

        if step % 10 == 0:
            print('step %d: loss %.3f' % (step, loss_v))

        if step % SAVE_EVERY == 0 and step > 0:
            saver.save(sess, model_ckpt)

    # Save the final weights as well; otherwise every step after the last
    # periodic checkpoint would be lost.
    saver.save(sess, model_ckpt)
else:
    saver.restore(sess, model_ckpt)

# Embed the test set with whichever model is now loaded. Doing this after the
# branch guarantees `embed` exists on the restore path too.
embed = siamese.o1.eval({siamese.x1: sketch_test})
embed.tofile('embed.txt')
# Visualize the learned embedding alongside the test images.
# NOTE(review): assumes each test sample is a flattened 28x28 image
# (784 values), as in the MNIST setup this script was adapted from — confirm.
x_test = sketch_test.reshape([-1, 28, 28])
visualize.visualize(embed, x_test)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment