Created
February 18, 2016 11:47
-
-
Save rohanar/5bcb51e92eac3049fb83 to your computer and use it in GitHub Desktop.
Test of the DL4J ParagraphVectors implementation, demonstrating how a stop-word list ("the", "after") affects the learned vocabulary.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.deeplearning4j.examples.paragraphvectors; | |
| import java.io.IOException; | |
| import java.util.ArrayList; | |
| import java.util.Collection; | |
| import java.util.HashMap; | |
| import java.util.List; | |
| import java.util.Map; | |
| import org.apache.commons.lang.time.StopWatch; | |
| import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | |
| import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; | |
| import org.deeplearning4j.models.word2vec.VocabWord; | |
| import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache; | |
| import org.deeplearning4j.text.documentiterator.LabelsSource; | |
| import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | |
| import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | |
| import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | |
| import org.slf4j.Logger; | |
| import org.slf4j.LoggerFactory; | |
| import org.synthesis.java.extension.SolrDocLabelAwareIterator2; | |
| /** | |
| * This is example code for dl4j ParagraphVectors implementation. In this example we build distributed representation of all sentences present in training corpus. | |
| * However, you still use it for training on labelled documents, using sets of LabelledDocument and LabelAwareIterator implementation. | |
| * | |
| * ************************************************************************************************* | |
| * PLEASE NOTE: THIS EXAMPLE REQUIRES DL4J/ND4J VERSIONS >= rc3.8 TO COMPILE SUCCESSFULLY | |
| * ************************************************************************************************* | |
| * | |
| * @author raver119@gmail.com | |
| */ | |
| public class ParagraphVectorLearningTest2 { | |
| private static final Logger log = LoggerFactory.getLogger(ParagraphVectorLearningTest2.class); | |
| public static void main(String[] args) throws Exception { | |
| // Manually create few test documents | |
| Map<String,String> docContentsMap = new HashMap<String,String>(); | |
| docContentsMap.put("MYDOC_1","An article using the Redirect template may be used to place links within navigational elements where they would otherwise be unsupported or require different text than the standard. Maturity: Production/Stable "); | |
| docContentsMap.put("MYDOC_2","An article using the Redirect template may be used to place links within navigational elements where they would otherwise be unsupported or require different text than the standard. Maturity: Production/Stable "); | |
| docContentsMap.put("MYDOC_3","We have compiled a list of frequently asked questions from residents and made them available online. Enter your question below and click search. If you don't find the answer to your question you can submit it for us to answer."); | |
| docContentsMap.put("MYDOC_4","Some general tips and tricks. Check the for general tips and tricks when creating an article. Additional Downloads"); | |
| String paraVecMdlFile = "mandocs" + docContentsMap.size() + ".txt"; | |
| //Vector Learning-related Settings | |
| boolean learnParaVecs = true; //if set to false, pre-trained model will be loaded | |
| int minWordFrequency = 1; | |
| int wordLearnIterations = 100; //iterations are over batches of paragraphs/docs | |
| int epochs = 10; //no of training epochs | |
| int layerSize = 200; /*length of a word/paragraph vector*/ | |
| double lr = 0.025; //0.025 | |
| //learn | |
| ParagraphVectors vec = null; | |
| StopWatch st = new StopWatch(); | |
| if(learnParaVecs) { | |
| vec = learnParagraphVectors(docContentsMap, paraVecMdlFile, minWordFrequency, wordLearnIterations, epochs, layerSize, lr); | |
| } else { | |
| st.reset(); | |
| st.start(); | |
| vec = WordVectorSerializer.readParagraphVectorsFromText(paraVecMdlFile); | |
| st.stop(); | |
| System.out.println("Time taken for reading paragraphVectors from disk: " + st.getTime() + "ms"); | |
| } | |
| printParagraphVector("the", vec); | |
| //Check the vocabulary | |
| printVocabulary2( vec); | |
| System.out.println("\nEnd Test"); | |
| } //end main() | |
| //==================Utility methods============== | |
| //print vocabulary | |
| private static void printVocabulary2(ParagraphVectors vec) { | |
| Collection<VocabWord> vocab = vec.getVocab().tokens(); | |
| for(VocabWord w: vocab) { | |
| //System.out.println(w.getWord() + " " + w.getLabel()); | |
| System.out.println(w.getWord()); | |
| } | |
| } | |
| private static ParagraphVectors learnParagraphVectors(Map<String,String> docContentsMap, String serialize2file, | |
| int minWordFrequency, int wordLearnIterations, int epochs, int layerSize, double lr) throws IOException { | |
| LabelsSource source = new LabelsSource(); | |
| // build a iterator for our dataset | |
| SolrDocLabelAwareIterator2 iterator = new SolrDocLabelAwareIterator2.Builder() | |
| .build(docContentsMap); | |
| InMemoryLookupCache cache = new InMemoryLookupCache(); | |
| TokenizerFactory t = new DefaultTokenizerFactory(); | |
| t.setTokenPreProcessor(new CommonPreprocessor()); | |
| List<String> stopList = new ArrayList<String>(); | |
| stopList.add("the"); | |
| stopList.add("after"); | |
| System.out.println("Stop Words: " + stopList); | |
| StopWatch sw = new StopWatch(); | |
| ParagraphVectors vec = new ParagraphVectors.Builder() | |
| .minWordFrequency(minWordFrequency) | |
| .stopWords(stopList) | |
| .iterations(wordLearnIterations) | |
| .epochs(epochs) | |
| .layerSize(layerSize) /*length of a paragraph vector*/ | |
| .learningRate(lr) | |
| .labelsSource(source) | |
| .windowSize(5) | |
| .iterate(iterator) | |
| .trainWordVectors(true) | |
| .vocabCache(cache) | |
| .tokenizerFactory(t) | |
| .sampling(0) | |
| .build(); | |
| sw.start(); | |
| vec.fit(); | |
| sw.stop(); | |
| System.out.println("Time taken to learn ParagraphVectors for documents is " + sw.getTime() + "ms"); | |
| //Serialising | |
| if(serialize2file != null && !serialize2file.isEmpty()) { | |
| WordVectorSerializer.writeWordVectors(vec, serialize2file); | |
| } | |
| return vec; | |
| } | |
| /**Print a paragraphVector */ | |
| private static void printParagraphVector(String token, ParagraphVectors vec) { | |
| if(vec.hasWord(token)) { | |
| double[] vdata = vec.getWordVector(token); | |
| System.out.print("\nVector of " + token + ": " ); | |
| for(int i=0; i< vdata.length; i++) { | |
| System.out.print(vdata[i] + " "); | |
| } | |
| System.out.println(); | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment