Skip to content

Instantly share code, notes, and snippets.

@rohanar
Created February 9, 2016 10:52
Show Gist options
  • Select an option

  • Save rohanar/5ee3d6abd4f50717f5a3 to your computer and use it in GitHub Desktop.

Select an option

Save rohanar/5ee3d6abd4f50717f5a3 to your computer and use it in GitHub Desktop.
ParagraphVectors learning Test code
package org.deeplearning4j.examples.paragraphvectors;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang.time.StopWatch;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.synthesis.java.extension.SolrDocLabelAwareIterator2;
/**
 * Example code for the dl4j ParagraphVectors implementation. In this example we build a
 * distributed representation of every document in a small, hand-built training corpus,
 * then compare two (identical) documents by cosine similarity and print their vectors.
 * You can also use the same setup for training on labelled documents via
 * LabelledDocument / LabelAwareIterator implementations.
 *
 * *************************************************************************************************
 * PLEASE NOTE: THIS EXAMPLE REQUIRES DL4J/ND4J VERSIONS >= rc3.8 TO COMPILE SUCCESSFULLY
 * *************************************************************************************************
 *
 * @author raver119@gmail.com
 */
public class ParagraphVectorLearningTest {

    private static final Logger log = LoggerFactory.getLogger(ParagraphVectorLearningTest.class);

    public static void main(String[] args) throws Exception {
        // Manually create a few test documents. MYDOC_1 and MYDOC_2 are intentionally
        // identical so their learned vectors should come out highly similar.
        Map<String, String> docContentsMap = new HashMap<>();
        docContentsMap.put("MYDOC_1", "An article using the Redirect template may be used to place links within navigational elements where they would otherwise be unsupported or require different text than the standard. Maturity: Production/Stable ");
        docContentsMap.put("MYDOC_2", "An article using the Redirect template may be used to place links within navigational elements where they would otherwise be unsupported or require different text than the standard. Maturity: Production/Stable ");
        docContentsMap.put("MYDOC_3", "We have compiled a list of frequently asked questions from residents and made them available online. Enter your question below and click search. If you don't find the answer to your question you can submit it for us to answer.");
        docContentsMap.put("MYDOC_4", "Some general tips and tricks. Check the for general tips and tricks when creating an article. Additional Downloads");

        // File the trained model is serialized to (or loaded from, when learnParaVecs is false).
        String paraVecMdlFile = "mandocs" + docContentsMap.size() + ".txt";

        // Vector-learning-related settings.
        boolean learnParaVecs = true;      // if set to false, the pre-trained model is loaded instead
        int minWordFrequency = 1;
        int wordLearnIterations = 100;
        int epochs = 10;                   // number of training epochs
        int layerSize = 10;                // length of a word/paragraph vector
        double lr = 0.01;                  // learning rate (original experiments also tried 0.025)

        // Either train a fresh model or deserialize a previously saved one.
        ParagraphVectors vec;
        StopWatch st = new StopWatch();
        if (learnParaVecs) {
            vec = learnParagraphVectors(docContentsMap, paraVecMdlFile, minWordFrequency,
                    wordLearnIterations, epochs, layerSize, lr);
        } else {
            st.reset();
            st.start();
            vec = WordVectorSerializer.readParagraphVectorsFromText(paraVecMdlFile);
            st.stop();
            System.out.println("Time taken for reading paragraphVectors from disk: " + st.getTime() + "ms");
        }

        // Sanity check: identical documents should score high similarity.
        double sim = vec.similarity("MYDOC_1", "MYDOC_2");
        log.info("MYDOC_1/MYDOC_2 similarity: {}", sim);  // parameterized SLF4J logging, no eager concat

        printParagraphVector("MYDOC_1", vec);
        printParagraphVector("MYDOC_2", vec);
        System.out.println("\nEnd Test");
    } // end main()

    // ================== Utility methods ==================

    /**
     * Trains ParagraphVectors over the given documents and optionally serializes the model.
     *
     * @param docContentsMap      document label -> document text
     * @param serialize2file      path to write the trained model to; skipped if null or empty
     * @param minWordFrequency    minimum word frequency for vocabulary inclusion
     * @param wordLearnIterations training iterations per word
     * @param epochs              number of passes over the corpus
     * @param layerSize           dimensionality of each word/paragraph vector
     * @param lr                  learning rate
     * @return the fitted ParagraphVectors model
     * @throws IOException if serializing the model to disk fails
     */
    private static ParagraphVectors learnParagraphVectors(Map<String, String> docContentsMap, String serialize2file,
            int minWordFrequency, int wordLearnIterations, int epochs, int layerSize, double lr) throws IOException {
        LabelsSource source = new LabelsSource();

        // Build an iterator for our dataset.
        SolrDocLabelAwareIterator2 iterator = new SolrDocLabelAwareIterator2.Builder()
                .build(docContentsMap);

        InMemoryLookupCache cache = new InMemoryLookupCache();

        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        StopWatch sw = new StopWatch();
        ParagraphVectors vec = new ParagraphVectors.Builder()
                .minWordFrequency(minWordFrequency)
                .iterations(wordLearnIterations)
                .epochs(epochs)
                .layerSize(layerSize) // length of a paragraph vector
                .learningRate(lr)
                .labelsSource(source)
                .windowSize(5)
                .iterate(iterator)
                .trainWordVectors(true)
                .vocabCache(cache)
                .tokenizerFactory(t)
                .sampling(0)
                .build();

        sw.start();
        vec.fit();
        sw.stop();
        System.out.println("Time taken to learn ParagraphVectors for documents is " + sw.getTime() + "ms");

        // Serialize the model only when a target file was supplied.
        if (serialize2file != null && !serialize2file.isEmpty()) {
            WordVectorSerializer.writeWordVectors(vec, serialize2file);
        }
        return vec;
    }

    /**
     * Prints the learned vector for the given document label, one component per line segment.
     * Logs a warning (instead of silently doing nothing) when the label is not in the model.
     *
     * @param docid document label whose vector to print
     * @param vec   fitted ParagraphVectors model
     */
    private static void printParagraphVector(String docid, ParagraphVectors vec) {
        if (!vec.hasWord(docid)) {
            log.warn("No vector found for label: {}", docid);
            return;
        }
        double[] vector = vec.getWordVector(docid);
        // Build the whole line once instead of one System.out.print call per component.
        StringBuilder sb = new StringBuilder();
        sb.append("\nVector of ").append(docid).append(": ");
        for (double component : vector) {
            sb.append(component).append(" ");
        }
        System.out.print(sb);
        System.out.println();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment