Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save mehmetakifakkus/b841a94d609891eb34a21111609f4198 to your computer and use it in GitHub Desktop.

Select an option

Save mehmetakifakkus/b841a94d609891eb34a21111609f4198 to your computer and use it in GitHub Desktop.
siamese
""" Siamese implementation using TensorFlow, adapted from the MNIST example
to a sketch dataset loaded from .npy files.
This siamese network embeds a 28x28 image (a point in 784D)
By Youngwook Paul Kwon (young at berkeley.edu)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#import system things
from tensorflow.examples.tutorials.mnist import input_data # for data
import tensorflow as tf
import numpy as np
import os
#import helpers
import inference
import visualize
# Load the sketch dataset (these .npy splits take the place of the original
# MNIST loader) and open an interactive session for the rest of the script.
sketch_test, sketch_test_label = (
    np.load("test_set.npy"),
    np.load("test_set_label.npy"),
)
sketch_training, sketch_training_label = (
    np.load("train_set.npy"),
    np.load("train_set_label.npy"),
)
sess = tf.InteractiveSession()
# Build the siamese graph, attach plain gradient descent (learning rate 0.01)
# to its loss, create a checkpoint saver, then run the variable initializer
# in the default session.
siamese = inference.siamese()
_sgd = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_step = _sgd.minimize(siamese.loss)
saver = tf.train.Saver()
tf.initialize_all_variables().run()
# Offer to resume from a previously trained model instead of training anew.
new = True  # stays True unless the user chooses to load the checkpoint
model_ckpt = 'model.ckpt'
# Older TensorFlow savers write a single file at this path; newer ones write
# V2 checkpoints as '<prefix>.index' plus '<prefix>.data-*' files, so accept
# either form.
if os.path.isfile(model_ckpt) or os.path.isfile(model_ckpt + '.index'):
    try:
        ask = raw_input  # Python 2: returns the typed line as a string
    except NameError:
        ask = input  # Python 3: raw_input was renamed to input
    input_var = None
    while input_var not in ['yes', 'no']:
        input_var = ask("We found model.ckpt file. Do you want to load it [yes/no]?")
    if input_var == 'yes':
        new = False
def getBatch(length, num, i):
    """Return slice bounds for the ``i``-th of ``num`` equal chunks of a dataset.

    The dataset of ``length`` items is divided into ``num`` contiguous chunks
    of ``length // num`` items each. Any remainder at the end is never used,
    so every chunk has the same size. Successive values of ``i`` cycle through
    the chunks.

    Args:
        length: number of items in the dataset.
        num: number of chunks per epoch; must be positive.
        i: running batch counter (non-negative).

    Returns:
        dict with ``'start'``/``'end'`` (integer half-open range ``[start, end)``,
        ready for slicing), ``'no'`` (chunk index within the epoch) and
        ``'epoch'`` (number of full passes completed before this batch).
    """
    # Floor division: with `from __future__ import division`, `/` would yield
    # float bounds, which NumPy rejects as slice indices.
    size = length // num
    no = i % num
    start = size * no
    # `end` is exclusive, so no `- 1`; it also never wraps to -1 on the last chunk.
    return {'start': start, 'end': start + size, 'no': no, 'epoch': i // num}
# start training (or restore the saved weights when the user asked to load them)
if new:
    n_train = len(sketch_training)
    n_chunks = 5  # the training set is walked in this many contiguous chunks
    for step in range(1100):
        chunk = getBatch(n_train, n_chunks, step)
        # int() keeps the slice bounds integral even if they arrive as floats.
        lo, hi = int(chunk['start']), int(chunk['end'])
        batch_x1 = sketch_training[lo:hi]
        batch_y1 = sketch_training_label[lo:hi]
        # Pair every sample with a random partner from the same chunk. Using the
        # identical chunk on both sides would make every pair label 1, so the
        # loss would never see a dissimilar pair. A permutation also guarantees
        # both sides have the same length.
        perm = np.random.permutation(len(batch_x1))
        batch_x2 = batch_x1[perm]
        batch_y2 = batch_y1[perm]
        batch_y = (batch_y1 == batch_y2).astype('float')
        _, loss_v = sess.run([train_step, siamese.loss], feed_dict={
            siamese.x1: batch_x1,
            siamese.x2: batch_x2,
            siamese.y_: batch_y})
        if np.isnan(loss_v):
            print('Model diverged with loss = NaN')
            raise SystemExit(1)  # non-zero status: training failed
        if step % 10 == 0:
            print('step %d: loss %.3f' % (step, loss_v))
        if step % 1000 == 0 and step > 0:
            saver.save(sess, model_ckpt)
            embed = siamese.o1.eval({siamese.x1: sketch_test})
            embed.tofile('embed.txt')
    # Persist the final weights as well; the periodic save above only fires at
    # step 1000, which would lose the remaining steps.
    saver.save(sess, model_ckpt)
else:
    saver.restore(sess, model_ckpt)
# visualize result
# Embed the test set with the current weights, whether freshly trained or
# restored. Computing it here guarantees `embed` exists on both paths, and
# both the embedding and the images come from `sketch_test`, the data this
# script actually loads.
embed = siamese.o1.eval({siamese.x1: sketch_test})
# NOTE(review): assumes each test sample holds 28*28 values -- confirm for the
# sketch data.
x_test = sketch_test.reshape([-1, 28, 28])
visualize.visualize(embed, x_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment