@rohanar
Last active April 17, 2020 14:21
KMeansClustering Example
package org.deeplearning4j.examples.paragraphvectors;
import org.apache.commons.lang.time.StopWatch;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.deeplearning4j.clustering.kmeans.KMeansClustering;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
/**
 * This is example code for the dl4j ParagraphVectors implementation. In this example we build a distributed representation of all sentences present in the training corpus.
 * However, you can also use it for training on labelled documents, using sets of LabelledDocument and a LabelAwareIterator implementation.
 *
 * *************************************************************************************************
 * PLEASE NOTE: THIS EXAMPLE REQUIRES DL4J/ND4J VERSIONS >= rc3.8 TO COMPILE SUCCESSFULLY
 * *************************************************************************************************
 *
 * @author raver119@gmail.com
 */
public class ParagraphVectorsTextKMeansExample {

    private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsTextKMeansExample.class);

    public static void main(String[] args) throws Exception {
        String dataFilePath = "/raw_sentences.txt"; // This has ~99000 single-sentence paragraphs/docs
        ClassPathResource resource = new ClassPathResource(dataFilePath);
        File file = resource.getFile();
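        // Read the corpus one sentence (line) at a time; tokens are normalized (lowercased) by the CommonPreprocessor set below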
        SentenceIterator iter = new BasicLineIterator(file);
        InMemoryLookupCache cache = new InMemoryLookupCache();
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());
        StopWatch sw = new StopWatch();
        /*
            If you don't have a LabelAwareIterator handy, you can use the synchronized labels generator below:
            it will be used to label each document/sequence/line with its own label.
            But if you have a LabelAwareIterator ready, you can provide it instead, for your in-house labels.
        */
        LabelsSource source = new LabelsSource("DOC_");
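        // A sketch of the labelled-corpus alternative mentioned above (assumes your dl4j version
        // exposes FileLabelAwareIterator and an iterate(LabelAwareIterator) overload; the path is hypothetical):
        // LabelAwareIterator labelledIter = new FileLabelAwareIterator.Builder()
        //         .addSourceFolder(new File("/path/to/labelled/docs"))
        //         .build();
        // ...then pass .iterate(labelledIter) to the builder below instead of .labelsSource(source).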
        ParagraphVectors vec = new ParagraphVectors.Builder()
                .minWordFrequency(1)
                .iterations(3)
                .epochs(1)
                .layerSize(100)         // length of a paragraph vector
                .learningRate(0.025)
                .labelsSource(source)
                .windowSize(5)
                .iterate(iter)
                .trainWordVectors(false)
                .vocabCache(cache)
                .tokenizerFactory(t)
                .sampling(0)
                .build();
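        // Training learns one vector per line of the corpus, labelled DOC_0, DOC_1, ... by the LabelsSource above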
        vec.fit();

        // 1. Create a KMeansClustering instance
        int maxIterationCount = 5;
        int clusterCount = 10;
        String distanceFunction = "cosinesimilarity";
        KMeansClustering kmc = KMeansClustering.setup(clusterCount, maxIterationCount, distanceFunction);

        // 2. Iterate over rows in the ParagraphVectors model and collect them into a List of paragraph vectors
        List<INDArray> vectors = new ArrayList<INDArray>();
        for (String word : vec.vocab().words()) {
            vectors.add(vec.getWordVectorMatrix(word));
        }
        log.info(vectors.size() + " vectors extracted to create Point list");
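        // Point.toPoints wraps each INDArray in a Point; note that the ids it assigns are generated, not document labels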
        List<Point> pointsLst = Point.toPoints(vectors);
        log.info(pointsLst.size() + " Points created out of " + vectors.size() + " vectors");

        log.info("Start Clustering " + pointsLst.size() + " points/docs");
        sw.reset();
        sw.start();
        ClusterSet cs = kmc.applyTo(pointsLst);
        sw.stop();
        System.out.println("Time taken to run clustering on " + vectors.size() + " paragraphVectors: " + sw.getTime());
        vectors = null;
        pointsLst = null;
        log.info("Finish Clustering");

        List<Cluster> clusterLst = cs.getClusters();
        System.out.println("\nCluster Centers:");
        for (Cluster c : clusterLst) {
            Point center = c.getCenter();
            System.out.println(center.getId());
        }

        log.info("Trying to classify a point that was used for generating the Clusters");
        double[] docVec = vec.getWordVector("DOC_400");
        Point newPoint = new Point("myid", "mylabel", docVec);
        PointClassification pc = cs.classifyPoint(newPoint);
        System.out.println(pc.getCluster().getCenter().getId());
        System.out.println("\nEnd Test");
    }
}
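A minimal sketch of a variation, assuming the same vec and kmc built above: if you want the points inside each cluster to carry readable document labels rather than generated ids, you can build the Point list yourself instead of calling Point.toPoints(vectors):

    List<Point> labelledPoints = new ArrayList<Point>();
    for (String label : vec.vocab().words()) {
        // use the DOC_ label as both id and label so clustered points stay human-readable
        labelledPoints.add(new Point(label, label, vec.getWordVector(label)));
    }
    ClusterSet labelledCs = kmc.applyTo(labelledPoints);

The cluster centers are still computed centroids (their ids stay generated), but the member points of each cluster should then expose the original DOC_ labels via getId()/getLabel().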
@karthikrao21

Thanks for the example. I tried something similar with Word2Vec. I got clusters. However, I could only get IDs (random hash values?), and I could not get the text of the words themselves. I tried center.getLabel().
Clearly, I need to learn more about these objects. Where can I find documentation that is more detailed than the javadocs?
Thanks, regards
Karthik
