
@rohanar
Last active April 17, 2020 14:21
KMeansClustering Example
package org.deeplearning4j.examples.paragraphvectors;
import org.apache.commons.lang.time.StopWatch;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.deeplearning4j.clustering.kmeans.KMeansClustering;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
/**
* This is example code for the dl4j ParagraphVectors implementation. In this example we build a distributed representation of all sentences present in the training corpus.
* However, you can still use it for training on labelled documents, using sets of LabelledDocument and a LabelAwareIterator implementation.
*
* *************************************************************************************************
* PLEASE NOTE: THIS EXAMPLE REQUIRES DL4J/ND4J VERSIONS >= rc3.8 TO COMPILE SUCCESSFULLY
* *************************************************************************************************
*
* @author raver119@gmail.com
*/
public class ParagraphVectorsTextKMeansExample {
private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsTextKMeansExample.class);
public static void main(String[] args) throws Exception {
String datafilepath = "/raw_sentences.txt"; //This has ~99000 single-sentence paragraphs/docs
ClassPathResource resource = new ClassPathResource(datafilepath);
File file = resource.getFile();
SentenceIterator iter = new BasicLineIterator(file);
InMemoryLookupCache cache = new InMemoryLookupCache();
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
StopWatch sw = new StopWatch();
/*
if you don't have a LabelAwareIterator handy, you can use a synchronized labels generator;
it will be used to label each document/sequence/line with its own label.
But if you have a LabelAwareIterator ready, you can provide it instead, for your in-house labels (see the commented sketch just below).
*/
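/*
A minimal sketch of that LabelAwareIterator route, assuming the FileLabelAwareIterator from the
org.deeplearning4j.text.documentiterator package (both it and LabelAwareIterator would need to be
imported) and a hypothetical folder layout with one sub-directory per label:

    LabelAwareIterator labelAwareIter = new FileLabelAwareIterator.Builder()
            .addSourceFolder(new File("/labelled_docs")) // hypothetical path; one sub-folder per label
            .build();
    ParagraphVectors labelledVec = new ParagraphVectors.Builder()
            .iterate(labelAwareIter)       // labels come from the iterator instead of a LabelsSource
            .tokenizerFactory(t)
            .build();
*/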
LabelsSource source = new LabelsSource("DOC_");
ParagraphVectors vec = new ParagraphVectors.Builder()
.minWordFrequency(1)
.iterations(3)
.epochs(1)
.layerSize(100) /*length of a paragraph vector*/
.learningRate(0.025)
.labelsSource(source)
.windowSize(5)
.iterate(iter)
.trainWordVectors(false)
.vocabCache(cache)
.tokenizerFactory(t)
.sampling(0)
.build();
vec.fit();
//1. create a kmeanscluster instance
int maxIterationCount = 5;
int clusterCount = 10;
String distanceFunction = "cosinesimilarity";
KMeansClustering kmc = KMeansClustering.setup(clusterCount, maxIterationCount, distanceFunction);
//2. iterate over the ParagraphVectors vocabulary (here, the DOC_ labels) and collect their vectors into a List
List<INDArray> vectors = new ArrayList<INDArray>();
for (String word : vec.vocab().words()) {
vectors.add(vec.getWordVectorMatrix(word));
}
log.info(vectors.size() + " vectors extracted to create Point list");
List<Point> pointsLst = Point.toPoints(vectors);
log.info(pointsLst.size() + " Points created out of " + vectors.size() + " vectors");
log.info("Start Clustering " + pointsLst.size() + " points/docs");
sw.reset();
sw.start();
ClusterSet cs = kmc.applyTo(pointsLst);
sw.stop();
System.out.println("Time taken to run clustering on " + vectors.size() + " paragraphVectors: " + sw.getTime());
vectors = null;
pointsLst = null;
log.info("Finish Clustering");
List<Cluster> clsterLst = cs.getClusters();
System.out.println("\nCluster Centers:");
for(Cluster c: clsterLst) {
Point center = c.getCenter();
System.out.println(center.getId());
}
log.info("Trying to classify a point that was used for generating the Clusters");
double[] newVec = vec.getWordVector("DOC_400");
Point newpoint = new Point("myid", "mylabel", newVec);
PointClassification pc = cs.classifyPoint(newpoint);
System.out.println(pc.getCluster().getCenter().getId());
System.out.println("\nEnd Test");
}
}
rohanar commented Feb 1, 2016

This code seems to deadlock sometimes. When it runs, it completes clustering within 2 minutes, but it can deadlock even when only 2000 lines of text are in the input data file. When it does not complete, it hangs and keeps the CPU 100% busy.
Also, in Eclipse, the execution does not seem to terminate even when the code completes clustering. You can see that the code has executed the last (debug) line, but the little "red box" in Eclipse that lets you terminate a running program remains active/red.

@karthikrao21

Thanks for the example. I tried something similar with Word2Vec. I got clusters. However, I could only get IDs (random hash values?), and I could not get the text of the words themselves. I tried center.getLabel().
Clearly, I need to learn more about these objects. Where can I find documentation that is more detailed than the javadocs?
Thanks, regards
Karthik
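
One way to tie the clusters back to the words themselves (a minimal sketch, assuming the same Point(id, label, double[]) constructor used in the example, and assuming Cluster exposes its member points via getPoints()) is to build each Point with the word as its id instead of calling Point.toPoints(vectors), and then read the ids of a cluster's members; vec and kmc are the ParagraphVectors and KMeansClustering instances from the example above:

    // build points keyed by the word itself, rather than letting Point.toPoints() generate ids
    List<Point> pointsLst = new ArrayList<Point>();
    for (String word : vec.vocab().words()) {
        pointsLst.add(new Point(word, word, vec.getWordVector(word)));
    }
    ClusterSet cs = kmc.applyTo(pointsLst);
    for (Cluster c : cs.getClusters()) {
        System.out.println("Cluster center: " + c.getCenter().getId());
        for (Point p : c.getPoints()) {   // member points carry the word as their id
            System.out.println("  " + p.getId());
        }
    }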
