Created
February 9, 2016 10:52
-
-
Save rohanar/5ee3d6abd4f50717f5a3 to your computer and use it in GitHub Desktop.
ParagraphVectors learning Test code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.deeplearning4j.examples.paragraphvectors; | |
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang.time.StopWatch;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.synthesis.java.extension.SolrDocLabelAwareIterator2;
| /** | |
| * This is example code for dl4j ParagraphVectors implementation. In this example we build distributed representation of all sentences present in training corpus. | |
| * However, you still use it for training on labelled documents, using sets of LabelledDocument and LabelAwareIterator implementation. | |
| * | |
| * ************************************************************************************************* | |
| * PLEASE NOTE: THIS EXAMPLE REQUIRES DL4J/ND4J VERSIONS >= rc3.8 TO COMPILE SUCCESSFULLY | |
| * ************************************************************************************************* | |
| * | |
| * @author raver119@gmail.com | |
| */ | |
| public class ParagraphVectorLearningTest { | |
| private static final Logger log = LoggerFactory.getLogger(ParagraphVectorLearningTest.class); | |
| public static void main(String[] args) throws Exception { | |
| // Manually create few test documents | |
| Map<String,String> docContentsMap = new HashMap<String,String>(); | |
| docContentsMap.put("MYDOC_1","An article using the Redirect template may be used to place links within navigational elements where they would otherwise be unsupported or require different text than the standard. Maturity: Production/Stable "); | |
| docContentsMap.put("MYDOC_2","An article using the Redirect template may be used to place links within navigational elements where they would otherwise be unsupported or require different text than the standard. Maturity: Production/Stable "); | |
| docContentsMap.put("MYDOC_3","We have compiled a list of frequently asked questions from residents and made them available online. Enter your question below and click search. If you don't find the answer to your question you can submit it for us to answer."); | |
| docContentsMap.put("MYDOC_4","Some general tips and tricks. Check the for general tips and tricks when creating an article. Additional Downloads"); | |
| String paraVecMdlFile = "mandocs" + docContentsMap.size() + ".txt"; | |
| //Vector Learning-related Settings | |
| boolean learnParaVecs = true; //if set to false, pre-trained model will be loaded | |
| int minWordFrequency = 1; | |
| int wordLearnIterations = 100; | |
| int epochs = 10; //no of training epochs | |
| int layerSize = 10; /*length of a word/paragraph vector*/ | |
| double lr = 0.01; //0.025 | |
| //learn | |
| ParagraphVectors vec = null; | |
| StopWatch st = new StopWatch(); | |
| if(learnParaVecs) { | |
| vec = learnParagraphVectors(docContentsMap, paraVecMdlFile, minWordFrequency, wordLearnIterations, epochs, layerSize, lr); | |
| } else { | |
| st.reset(); | |
| st.start(); | |
| vec = WordVectorSerializer.readParagraphVectorsFromText(paraVecMdlFile); | |
| st.stop(); | |
| System.out.println("Time taken for reading paragraphVectors from disk: " + st.getTime() + "ms"); | |
| } | |
| double sim = vec.similarity("MYDOC_1", "MYDOC_2"); | |
| log.info("MYDOC_1/MYDOC_2 similarity: " + sim); | |
| printParagraphVector("MYDOC_1", vec); | |
| printParagraphVector("MYDOC_2", vec); | |
| System.out.println("\nEnd Test"); | |
| } //end main() | |
| //==================Utility methods============== | |
| private static ParagraphVectors learnParagraphVectors(Map<String,String> docContentsMap, String serialize2file, | |
| int minWordFrequency, int wordLearnIterations, int epochs, int layerSize, double lr) throws IOException { | |
| LabelsSource source = new LabelsSource(); | |
| // build a iterator for our dataset | |
| SolrDocLabelAwareIterator2 iterator = new SolrDocLabelAwareIterator2.Builder() | |
| .build(docContentsMap); | |
| InMemoryLookupCache cache = new InMemoryLookupCache(); | |
| TokenizerFactory t = new DefaultTokenizerFactory(); | |
| t.setTokenPreProcessor(new CommonPreprocessor()); | |
| StopWatch sw = new StopWatch(); | |
| ParagraphVectors vec = new ParagraphVectors.Builder() | |
| .minWordFrequency(minWordFrequency) | |
| .iterations(wordLearnIterations) | |
| .epochs(epochs) | |
| .layerSize(layerSize) /*length of a paragraph vector*/ | |
| .learningRate(lr) | |
| .labelsSource(source) | |
| .windowSize(5) | |
| .iterate(iterator) | |
| .trainWordVectors(true) | |
| .vocabCache(cache) | |
| .tokenizerFactory(t) | |
| .sampling(0) | |
| .build(); | |
| sw.start(); | |
| vec.fit(); | |
| sw.stop(); | |
| System.out.println("Time taken to learn ParagraphVectors for documents is " + sw.getTime() + "ms"); | |
| //Serialising | |
| if(serialize2file != null && !serialize2file.isEmpty()) { | |
| WordVectorSerializer.writeWordVectors(vec, serialize2file); | |
| } | |
| return vec; | |
| } | |
| /**Print a paragraphVector */ | |
| private static void printParagraphVector(String docid, ParagraphVectors vec) { | |
| if(vec.hasWord(docid)) { | |
| double[] V_city = vec.getWordVector(docid); | |
| System.out.print("\nVector of " + docid + ": " ); | |
| for(int i=0; i< V_city.length; i++) { | |
| System.out.print(V_city[i] + " "); | |
| } | |
| System.out.println(); | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment