Package cc.mallet.topics
Class ParallelTopicModel
- java.lang.Object
  - cc.mallet.topics.ParallelTopicModel
- All Implemented Interfaces:
java.io.Serializable
- Direct Known Subclasses:
DMRTopicModel, RTopicModel
public class ParallelTopicModel extends java.lang.Object implements java.io.Serializable
Simple parallel threaded implementation of LDA, following Newman, Asuncion, Smyth and Welling, "Distributed Algorithms for Topic Models", JMLR (2009), with the SparseLDA sampling scheme and data structure from Yao, Mimno and McCallum, "Efficient Methods for Topic Model Inference on Streaming Document Collections", KDD (2009).
- Author:
- David Mimno, Andrew McCallum
- See Also:
- Serialized Form
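The class description cites the SparseLDA sampling scheme. As a rough illustration (a sketch of the published idea, not MALLET's code), the collapsed Gibbs weight for giving topic t to a token of word w in document d, (alpha[t] + n(t,d)) * (beta + n(w,t)) / (betaSum + n(t)), can be split into three buckets: a smoothing-only bucket s, a document-topic bucket r and a topic-word bucket q. Here betaSum is assumed to be beta times the vocabulary size. The class and method names are hypothetical, and the loops are dense only to check that s + r + q equals the direct sum; SparseLDA's speed-up comes from caching s and visiting only the nonzero terms of r and q.

```java
// Hypothetical illustration of the SparseLDA bucket split; not MALLET's implementation.
final class SparseLdaBuckets {
    /**
     * Returns {s, r, q} for one token of a word in a document.
     * All counts must already exclude the token being resampled.
     */
    static double[] buckets(double[] alpha, double beta, double betaSum,
                            int[] docTopicCounts, int[] wordTopicCounts, int[] tokensPerTopic) {
        double s = 0, r = 0, q = 0;
        for (int t = 0; t < alpha.length; t++) {
            double denom = betaSum + tokensPerTopic[t];
            s += alpha[t] * beta / denom;          // depends only on topic totals, so it can be cached
            r += docTopicCounts[t] * beta / denom; // nonzero only for topics present in the document
            q += (alpha[t] + docTopicCounts[t]) * wordTopicCounts[t] / denom; // nonzero only for the word's topics
        }
        return new double[] {s, r, q};
    }

    /** The same total computed directly from the unsplit Gibbs weight. */
    static double direct(double[] alpha, double beta, double betaSum,
                         int[] docTopicCounts, int[] wordTopicCounts, int[] tokensPerTopic) {
        double total = 0;
        for (int t = 0; t < alpha.length; t++) {
            total += (alpha[t] + docTopicCounts[t]) * (beta + wordTopicCounts[t])
                     / (betaSum + tokensPerTopic[t]);
        }
        return total;
    }
}
```

The split is exact because (alpha + n_d)(beta + n_w) = alpha*beta + n_d*beta + (alpha + n_d)*n_w.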
-
Field Summary
double[] alpha
Alphabet alphabet
double alphaSum
double beta
double betaSum
int burninPeriod
java.util.ArrayList<TopicAssignment> data
static double DEFAULT_BETA
int[] docLengthCounts
java.text.NumberFormat formatter
static java.util.logging.Logger logger
java.lang.String modelFilename
int numIterations
int numTopics
int numTypes
int optimizeInterval
boolean printLogLikelihood
int randomSeed
int saveModelInterval
int saveSampleInterval
int saveStateInterval
int showTopicsInterval
java.lang.String stateFilename
int temperingInterval
int[] tokensPerTopic
LabelAlphabet topicAlphabet
int topicBits
int[][] topicDocCounts
int topicMask
long totalTokens
int[][] typeTopicCounts
static int UNASSIGNED_TOPIC
boolean usingSymmetricAlpha
int wordsPerTopic
-
Constructor Summary
ParallelTopicModel(int numberOfTopics)
ParallelTopicModel(int numberOfTopics, double alphaSum, double beta)
ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta)
-
Method Summary
void addInstances(InstanceList training)
void buildInitialTypeTopicCounts()
java.lang.String displayTopWords(int numWords, boolean usingNewLines)
void estimate()
Alphabet getAlphabet()
java.util.ArrayList<TopicAssignment> getData()
double[][] getDocumentTopics(boolean normalized, boolean smoothed)
TopicInferencer getInferencer()
    Return a tool for estimating topic distributions for new documents.
int getNumTopics()
MarginalProbEstimator getProbEstimator()
    Return a tool for evaluating the marginal probability of new documents under this model.
java.util.ArrayList<java.util.TreeSet<IDSorter>> getSortedWords()
    Return a list of sorted sets (one set per topic).
double[][] getSubCorpusTopicWords(boolean[] documentMask, boolean normalized, boolean smoothed)
int[] getTokensPerTopic()
LabelAlphabet getTopicAlphabet()
java.util.ArrayList<java.util.TreeSet<IDSorter>> getTopicDocuments(double smoothing)
double[] getTopicProbabilities(int instanceID)
    Get the smoothed distribution over topics for a training instance.
double[] getTopicProbabilities(LabelSequence topics)
    Get the smoothed distribution over topics for a topic sequence, which may be from the training set or from a new instance with topics assigned by an inferencer.
double[][] getTopicWords(boolean normalized, boolean smoothed)
java.lang.Object[][] getTopWords(int numWords)
    Return an array (one element per topic) of arrays of words, listing the most probable words for that topic in descending order.
int[][] getTypeTopicCounts()
void initializeFromState(java.io.File stateFile)
void maximize(int iterations)
    This method implements iterated conditional modes (ICM), which works like Gibbs sampling except that each token takes the topic with maximum conditional probability instead of a sampled one.
double modelLogLikelihood()
void optimizeAlpha(WorkerCallable[] callables)
void optimizeBeta(WorkerCallable[] callables)
void printDenseDocumentTopics(java.io.PrintWriter out)
void printDocumentTopics(java.io.File file)
void printDocumentTopics(java.io.PrintWriter out)
void printDocumentTopics(java.io.PrintWriter out, double threshold, int max)
void printState(java.io.File f)
void printState(java.io.PrintStream out)
void printTopicDocuments(java.io.PrintWriter out)
void printTopicDocuments(java.io.PrintWriter out, int max)
void printTopicWordWeights(java.io.File file)
void printTopicWordWeights(java.io.PrintWriter out)
    Print an unnormalized weight for every word in every topic.
void printTopWords(java.io.File file, int numWords, boolean useNewLines)
void printTopWords(java.io.PrintStream out, int numWords, boolean usingNewLines)
void printTypeTopicCounts(java.io.File file)
    Write the internal representation of type-topic counts (count/topic pairs in descending order by count) to a file.
static ParallelTopicModel read(java.io.File f)
void setBurninPeriod(int burninPeriod)
void setNumIterations(int numIterations)
void setNumThreads(int threads)
void setNumTopics(int numTopics)
    Set or reset the number of topics.
void setOptimizeInterval(int interval)
    Set the interval, in iterations, for optimizing the Dirichlet hyperparameters.
void setRandomSeed(int seed)
void setSaveSerializedModel(int interval, java.lang.String filename)
    Define how often and where to save a serialized model.
void setSaveState(int interval, java.lang.String filename)
    Define how often and where to save a text representation of the current state.
void setSymmetricAlpha(boolean b)
void setTemperingInterval(int interval)
void setTopicDisplay(int interval, int n)
void temperAlpha(WorkerCallable[] callables)
void topicPhraseXMLReport(java.io.PrintWriter out, int numWords)
void topicXMLReport(java.io.PrintWriter out, int numWords)
void write(java.io.File serializedModelFile)
write(java.io.File serializedModelFile)
-
Field Detail
-
UNASSIGNED_TOPIC
public static final int UNASSIGNED_TOPIC
- See Also:
- Constant Field Values
-
logger
public static java.util.logging.Logger logger
-
data
public java.util.ArrayList<TopicAssignment> data
-
alphabet
public Alphabet alphabet
-
topicAlphabet
public LabelAlphabet topicAlphabet
-
numTopics
public int numTopics
-
topicMask
public int topicMask
-
topicBits
public int topicBits
-
numTypes
public int numTypes
-
totalTokens
public long totalTokens
-
alpha
public double[] alpha
-
alphaSum
public double alphaSum
-
beta
public double beta
-
betaSum
public double betaSum
-
usingSymmetricAlpha
public boolean usingSymmetricAlpha
-
DEFAULT_BETA
public static final double DEFAULT_BETA
- See Also:
- Constant Field Values
-
typeTopicCounts
public int[][] typeTopicCounts
-
tokensPerTopic
public int[] tokensPerTopic
-
docLengthCounts
public int[] docLengthCounts
-
topicDocCounts
public int[][] topicDocCounts
-
numIterations
public int numIterations
-
burninPeriod
public int burninPeriod
-
saveSampleInterval
public int saveSampleInterval
-
optimizeInterval
public int optimizeInterval
-
temperingInterval
public int temperingInterval
-
showTopicsInterval
public int showTopicsInterval
-
wordsPerTopic
public int wordsPerTopic
-
saveStateInterval
public int saveStateInterval
-
stateFilename
public java.lang.String stateFilename
-
saveModelInterval
public int saveModelInterval
-
modelFilename
public java.lang.String modelFilename
-
randomSeed
public int randomSeed
-
formatter
public java.text.NumberFormat formatter
-
printLogLikelihood
public boolean printLogLikelihood
-
Constructor Detail
-
ParallelTopicModel
public ParallelTopicModel(int numberOfTopics)
-
ParallelTopicModel
public ParallelTopicModel(int numberOfTopics, double alphaSum, double beta)
-
ParallelTopicModel
public ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta)
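The constructors take a total Dirichlet mass, alphaSum, rather than per-topic values. A minimal sketch, assuming a symmetric prior in which each of the numTopics entries of alpha receives alphaSum / numTopics, and assuming betaSum is beta multiplied by the vocabulary size (numTypes). The PriorSetup class is hypothetical and only mirrors the relationship between the alpha, alphaSum, beta and betaSum fields.

```java
import java.util.Arrays;

// Hypothetical helper: how alphaSum and beta could map onto the per-topic and summed priors.
final class PriorSetup {
    /** Symmetric prior: every topic receives an equal share of alphaSum (assumption). */
    static double[] symmetricAlpha(int numTopics, double alphaSum) {
        double[] alpha = new double[numTopics];
        Arrays.fill(alpha, alphaSum / numTopics);
        return alpha;
    }

    /** beta summed over every word type in the vocabulary (assumption). */
    static double betaSum(double beta, int numTypes) {
        return beta * numTypes;
    }
}
```

Under these assumptions, a model built with numberOfTopics = 50 and alphaSum = 5.0 would start every topic with alpha = 0.1.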
-
Method Detail
-
getAlphabet
public Alphabet getAlphabet()
-
getTopicAlphabet
public LabelAlphabet getTopicAlphabet()
-
getNumTopics
public int getNumTopics()
-
setNumTopics
public void setNumTopics(int numTopics)
Set or reset the number of topics. This method will not change any token-topic assignments, so it should only be used before initializing or restoring a previously saved state.
-
getData
public java.util.ArrayList<TopicAssignment> getData()
-
getTypeTopicCounts
public int[][] getTypeTopicCounts()
-
getTokensPerTopic
public int[] getTokensPerTopic()
-
setNumIterations
public void setNumIterations(int numIterations)
-
setBurninPeriod
public void setBurninPeriod(int burninPeriod)
-
setTopicDisplay
public void setTopicDisplay(int interval, int n)
-
setRandomSeed
public void setRandomSeed(int seed)
-
setOptimizeInterval
public void setOptimizeInterval(int interval)
Set the interval, in iterations, for optimizing the Dirichlet hyperparameters.
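To make the interval concrete, the sketch below lists the iterations at which hyperparameter optimization would run. It assumes, without confirmation from this page, that iterations are numbered from 1, that optimization only happens after burninPeriod, and that an interval of 0 disables it. OptimizeSchedule is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical schedule for hyperparameter optimization under the assumptions above.
final class OptimizeSchedule {
    static List<Integer> optimizationIterations(int numIterations, int burninPeriod, int optimizeInterval) {
        List<Integer> result = new ArrayList<>();
        if (optimizeInterval <= 0) {
            return result; // assumed: a non-positive interval disables optimization
        }
        for (int iteration = 1; iteration <= numIterations; iteration++) {
            // Optimize only once burn-in is over, on multiples of the interval.
            if (iteration > burninPeriod && iteration % optimizeInterval == 0) {
                result.add(iteration);
            }
        }
        return result;
    }
}
```

For example, 100 iterations with burninPeriod 20 and interval 10 would optimize at iterations 30, 40, ..., 100.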
-
setSymmetricAlpha
public void setSymmetricAlpha(boolean b)
-
setTemperingInterval
public void setTemperingInterval(int interval)
-
setNumThreads
public void setNumThreads(int threads)
-
setSaveState
public void setSaveState(int interval, java.lang.String filename)
Define how often and where to save a text representation of the current state. Files are GZipped.
- Parameters:
interval - Save a copy of the state every interval iterations.
filename - Save the state to this file, with the iteration number as a suffix.
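The following sketch shows one reading of setSaveState: a snapshot every interval iterations, written GZip-compressed to a file named after filename plus the iteration number. The "." separator, the modulus test and the StateSnapshots class are assumptions for illustration; only the java.io and java.util.zip calls are standard library API.

```java
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical snapshot helper; naming and schedule are assumptions, not MALLET's code.
final class StateSnapshots {
    /** Base filename plus the iteration number as a suffix, e.g. "state.gz.100". */
    static String snapshotName(String filename, int iteration) {
        return filename + "." + iteration;
    }

    /** True if a snapshot would be written at this iteration (interval 0 means never). */
    static boolean shouldSave(int iteration, int interval) {
        return interval > 0 && iteration % interval == 0;
    }

    /** Writes text lines to a GZip-compressed file. */
    static void writeGzipped(File file, String[] lines) throws IOException {
        try (Writer w = new OutputStreamWriter(
                new GZIPOutputStream(new FileOutputStream(file)), StandardCharsets.UTF_8)) {
            for (String line : lines) {
                w.write(line);
                w.write('\n');
            }
        }
    }

    /** Reads back the first line, showing the file is ordinary GZip data. */
    static String readFirstLine(File file) throws IOException {
        try (BufferedReader r = new BufferedReader(new InputStreamReader(
                new GZIPInputStream(new FileInputStream(file)), StandardCharsets.UTF_8))) {
            return r.readLine();
        }
    }
}
```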
-
setSaveSerializedModel
public void setSaveSerializedModel(int interval, java.lang.String filename)
Define how often and where to save a serialized model.
- Parameters:
interval - Save a serialized model every interval iterations.
filename - Save to this file, with the iteration number as a suffix.
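The class header lists java.io.Serializable among the implemented interfaces, and the Method Detail includes write(File) and a static read(File). The sketch below shows the standard java.io object-stream round trip that any Serializable class supports, using a hypothetical TinyModel stand-in; whether ParallelTopicModel's own write and read use exactly this mechanism is an assumption.

```java
import java.io.*;
import java.util.Arrays;

// Hypothetical stand-in for a serializable model; not ParallelTopicModel itself.
final class TinyModel implements Serializable {
    private static final long serialVersionUID = 1L;

    int numTopics;
    double[] alpha;

    TinyModel(int numTopics, double alphaSum) {
        this.numTopics = numTopics;
        this.alpha = new double[numTopics];
        Arrays.fill(alpha, alphaSum / numTopics);
    }

    /** Serializes this object to a file with an ObjectOutputStream. */
    void write(File f) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(f))) {
            out.writeObject(this);
        }
    }

    /** Reads an object written by write(File) back into memory. */
    static TinyModel read(File f) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(f))) {
            return (TinyModel) in.readObject();
        }
    }
}
```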
-
addInstances
public void addInstances(InstanceList training)
-
initializeFromState
public void initializeFromState(java.io.File stateFile) throws java.io.IOException
- Throws:
java.io.IOException
-
buildInitialTypeTopicCounts
public void buildInitialTypeTopicCounts()
-
optimizeAlpha
public void optimizeAlpha(WorkerCallable[] callables)
-
temperAlpha
public void temperAlpha(WorkerCallable[] callables)
-
optimizeBeta
public void optimizeBeta(WorkerCallable[] callables)
-
estimate
public void estimate() throws java.io.IOException
- Throws:
java.io.IOException
-
maximize
public void maximize(int iterations)
This method implements iterated conditional modes (ICM), which works like Gibbs sampling except that, instead of sampling from the conditional distribution, each token is assigned the topic with the maximum conditional probability. It tends to converge within a small number of iterations for models that have already reached a good state through Gibbs sampling.
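A minimal sketch of one ICM sweep over a single document, assuming the same unnormalized weight used in collapsed Gibbs sampling for LDA, (alpha[t] + n(t,d)) * (beta + n(w,t)) / (betaSum + n(t)), with the current token's counts removed before scoring. IcmSketch and its parameters are hypothetical; the only change from a Gibbs step is that the loop keeps the argmax topic instead of drawing one at random.

```java
// Hypothetical ICM sweep for one document; not MALLET's maximize() implementation.
final class IcmSketch {
    /**
     * tokens[i] is a word id and topics[i] its current topic. The count arrays are
     * updated in place. Returns how many tokens changed topic during the sweep.
     */
    static int sweep(int[] tokens, int[] topics, double[] alpha, double beta, double betaSum,
                     int[] docTopicCounts, int[][] typeTopicCounts, int[] tokensPerTopic) {
        int changed = 0;
        for (int i = 0; i < tokens.length; i++) {
            int w = tokens[i];
            int old = topics[i];
            // Remove the token from all counts before scoring.
            docTopicCounts[old]--;
            typeTopicCounts[w][old]--;
            tokensPerTopic[old]--;

            int best = 0;
            double bestScore = -1;
            for (int t = 0; t < alpha.length; t++) {
                double score = (alpha[t] + docTopicCounts[t]) * (beta + typeTopicCounts[w][t])
                               / (betaSum + tokensPerTopic[t]);
                if (score > bestScore) {
                    bestScore = score;
                    best = t;
                }
            }

            // Take the maximum instead of sampling, then add the token back.
            topics[i] = best;
            docTopicCounts[best]++;
            typeTopicCounts[w][best]++;
            tokensPerTopic[best]++;
            if (best != old) changed++;
        }
        return changed;
    }
}
```

A sweep that changes no assignments means the state is a fixed point of ICM, which is one way to detect the quick convergence described above.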
-
getSortedWords
public java.util.ArrayList<java.util.TreeSet<IDSorter>> getSortedWords()
Return a list of sorted sets (one set per topic). Each set contains IDSorter objects whose integer IDs are indices into the alphabet. To get the word Strings directly, use getTopWords().
-
getTopWords
public java.lang.Object[][] getTopWords(int numWords)
Return an array (one element per topic) of arrays of words, listing the most probable words for that topic in descending order. These are returned as Objects, but will probably be Strings.
- Parameters:
numWords - The maximum length of each topic's array of words (an array may be shorter).
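getSortedWords and getTopWords describe a two-step lookup: sort word IDs per topic by weight, then map the first numWords IDs to strings. The sketch below assumes the weight is the word's count in the topic and that only words with a nonzero count are included, which is one way a topic's array can come out shorter than numWords. WordWeight stands in for MALLET's IDSorter, and all names are hypothetical.

```java
import java.util.Comparator;
import java.util.TreeSet;

// Hypothetical re-creation of the idea behind getSortedWords()/getTopWords().
final class TopWordsSketch {
    static final class WordWeight {
        final int id;
        final double weight;

        WordWeight(int id, double weight) {
            this.id = id;
            this.weight = weight;
        }
    }

    // Larger weight first; ties broken by smaller id so distinct words never collide in the set.
    static final Comparator<WordWeight> ORDER = (a, b) -> {
        int byWeight = Double.compare(b.weight, a.weight);
        return byWeight != 0 ? byWeight : Integer.compare(a.id, b.id);
    };

    /** One sorted set for a topic, holding only words with a nonzero count. */
    static TreeSet<WordWeight> sortedWords(int[] wordCountsForTopic) {
        TreeSet<WordWeight> set = new TreeSet<>(ORDER);
        for (int w = 0; w < wordCountsForTopic.length; w++) {
            if (wordCountsForTopic[w] > 0) {
                set.add(new WordWeight(w, wordCountsForTopic[w]));
            }
        }
        return set;
    }

    /** At most numWords strings, most probable first; shorter if the set is smaller. */
    static String[] topWords(TreeSet<WordWeight> sorted, String[] vocabulary, int numWords) {
        int n = Math.min(numWords, sorted.size());
        String[] out = new String[n];
        int i = 0;
        for (WordWeight ww : sorted) {
            if (i == n) break;
            out[i++] = vocabulary[ww.id];
        }
        return out;
    }
}
```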
-
printTopWords
public void printTopWords(java.io.File file, int numWords, boolean useNewLines) throws java.io.IOException
- Throws:
java.io.IOException
-
printTopWords
public void printTopWords(java.io.PrintStream out, int numWords, boolean usingNewLines)
-
displayTopWords
public java.lang.String displayTopWords(int numWords, boolean usingNewLines)
-
topicXMLReport
public void topicXMLReport(java.io.PrintWriter out, int numWords)
-
topicPhraseXMLReport
public void topicPhraseXMLReport(java.io.PrintWriter out, int numWords)
-
printTypeTopicCounts
public void printTypeTopicCounts(java.io.File file) throws java.io.IOException
Write the internal representation of type-topic counts (count/topic pairs in descending order by count) to a file.
- Throws:
java.io.IOException
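The description of printTypeTopicCounts mentions count/topic pairs, and the class exposes topicMask and topicBits fields. A plausible reading, treated here as an assumption rather than documented behavior, is that each pair is packed into one int with the topic in the low topicBits bits and the count above them. The hypothetical PackedCounts class sketches that layout, sizing the mask as the smallest all-ones bit pattern that can hold every topic index.

```java
// Hypothetical packing of a (count, topic) pair into one int; inferred, not taken from MALLET.
final class PackedCounts {
    /** Smallest all-ones mask that can represent topic indices 0..numTopics-1. */
    static int topicMask(int numTopics) {
        if (Integer.bitCount(numTopics) == 1) {
            return numTopics - 1;                    // power of two, e.g. 64 -> 63
        }
        return Integer.highestOneBit(numTopics) * 2 - 1; // e.g. 100 -> 127
    }

    static int topicBits(int topicMask) {
        return Integer.bitCount(topicMask);
    }

    static int pack(int count, int topic, int topicBits) {
        return (count << topicBits) | topic;         // count in the high bits, topic in the low bits
    }

    static int topic(int packed, int topicMask) {
        return packed & topicMask;
    }

    static int count(int packed, int topicBits) {
        return packed >> topicBits;
    }
}
```

Because the count occupies the high bits, ordering packed values numerically from largest to smallest also orders them by count, consistent with the "descending order by count" wording.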
-
printTopicWordWeights
public void printTopicWordWeights(java.io.File file) throws java.io.IOException
- Throws:
java.io.IOException
-
printTopicWordWeights
public void printTopicWordWeights(java.io.PrintWriter out) throws java.io.IOException
Print an unnormalized weight for every word in every topic. Most of these will be equal to the smoothing parameter beta.
- Throws:
java.io.IOException
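One reading of "unnormalized weight" consistent with the remark that most values equal beta is beta plus the number of tokens of that word assigned to the topic. The sketch below assumes that formula and a dense topic-by-word count matrix (MALLET's own typeTopicCounts layout differs); TopicWordWeights is a hypothetical name.

```java
// Hypothetical topic-word weights: beta + count, with an optional normalization step.
final class TopicWordWeights {
    static double[][] weights(int[][] topicWordCounts, double beta) {
        double[][] w = new double[topicWordCounts.length][];
        for (int t = 0; t < topicWordCounts.length; t++) {
            w[t] = new double[topicWordCounts[t].length];
            for (int v = 0; v < w[t].length; v++) {
                // Words never assigned to topic t keep weight exactly beta.
                w[t][v] = beta + topicWordCounts[t][v];
            }
        }
        return w;
    }

    /** Dividing a row by its sum gives a smoothed p(word | topic). */
    static double[] normalize(double[] row) {
        double sum = 0;
        for (double x : row) sum += x;
        double[] p = new double[row.length];
        for (int i = 0; i < row.length; i++) p[i] = row[i] / sum;
        return p;
    }
}
```

In a realistic vocabulary most words have a zero count in any given topic, which is why most printed weights would equal beta.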
-
getTopicProbabilities
public double[] getTopicProbabilities(int instanceID)
Get the smoothed distribution over topics for a training instance.
-
getTopicProbabilities
public double[] getTopicProbabilities(LabelSequence topics)
Get the smoothed distribution over topics for a topic sequence, which may be from the training set or from a new instance with topics assigned by an inferencer.
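A smoothed distribution over topics can be sketched by adding the Dirichlet prior to the per-topic token counts and normalizing: p(t | d) = (alpha[t] + n(t, d)) / (alphaSum + docLength). This formula is an assumption consistent with the alpha and alphaSum fields, and SmoothedTopics is a hypothetical helper; its input array plays the role of a document's topic sequence, whether from training or from an inferencer.

```java
// Hypothetical smoothed document-topic distribution; the formula is an assumption.
final class SmoothedTopics {
    static double[] topicProbabilities(int[] tokenTopics, double[] alpha) {
        double[] dist = alpha.clone();           // start from the Dirichlet prior
        double alphaSum = 0;
        for (double a : alpha) alphaSum += a;
        for (int topic : tokenTopics) {
            dist[topic]++;                       // one count per token assigned to the topic
        }
        double total = alphaSum + tokenTopics.length;
        for (int t = 0; t < dist.length; t++) {
            dist[t] /= total;                    // normalize so the entries sum to 1
        }
        return dist;
    }
}
```

Because of the alpha term, every topic gets nonzero probability even when no token in the document is assigned to it.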
-
printDocumentTopics
public void printDocumentTopics(java.io.File file) throws java.io.IOException
- Throws:
java.io.IOException
-
printDenseDocumentTopics
public void printDenseDocumentTopics(java.io.PrintWriter out)
-
printDocumentTopics
public void printDocumentTopics(java.io.PrintWriter out)
-
printDocumentTopics
public void printDocumentTopics(java.io.PrintWriter out, double threshold, int max)
- Parameters:
out - A print writer
threshold - Only print topics with proportion greater than this number
max - Print no more than this many topics
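The threshold and max parameters combine as a filter followed by a cap. A minimal sketch, assuming topics are considered in descending order of proportion; the DocTopicFilter class and its "topic=proportion" strings are hypothetical and do not reproduce MALLET's output format.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical selection logic for printDocumentTopics(out, threshold, max).
final class DocTopicFilter {
    static List<String> select(double[] proportions, double threshold, int max) {
        Integer[] order = new Integer[proportions.length];
        for (int t = 0; t < order.length; t++) order[t] = t;
        // Sort topic indices by proportion, largest first (stable for ties).
        Arrays.sort(order, (a, b) -> Double.compare(proportions[b], proportions[a]));

        List<String> out = new ArrayList<>();
        for (int t : order) {
            if (out.size() >= max) break;           // print no more than max topics
            if (proportions[t] <= threshold) break; // sorted, so no later topic can qualify
            out.add(t + "=" + proportions[t]);
        }
        return out;
    }
}
```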
-
getSubCorpusTopicWords
public double[][] getSubCorpusTopicWords(boolean[] documentMask, boolean normalized, boolean smoothed)
-
getTopicWords
public double[][] getTopicWords(boolean normalized, boolean smoothed)
-
getDocumentTopics
public double[][] getDocumentTopics(boolean normalized, boolean smoothed)
-
getTopicDocuments
public java.util.ArrayList<java.util.TreeSet<IDSorter>> getTopicDocuments(double smoothing)
-
printTopicDocuments
public void printTopicDocuments(java.io.PrintWriter out)
-
printTopicDocuments
public void printTopicDocuments(java.io.PrintWriter out, int max)
- Parameters:
out - A print writer
max - Print this number of top documents
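printTopicDocuments lists the top documents for each topic, and getTopicDocuments takes a smoothing argument. The sketch below ranks documents for one topic by (n(t, d) + smoothing) / (docLength + numTopics * smoothing); this smoothed proportion is an assumption about what smoothing does, and TopicDocRanking is hypothetical.

```java
import java.util.Arrays;

// Hypothetical ranking of documents for a single topic by smoothed proportion.
final class TopicDocRanking {
    static int[] topDocuments(int[][] docTopicAssignments, int numTopics, int topic,
                              double smoothing, int max) {
        int numDocs = docTopicAssignments.length;
        double[] score = new double[numDocs];
        for (int d = 0; d < numDocs; d++) {
            int count = 0;
            for (int z : docTopicAssignments[d]) {
                if (z == topic) count++;            // tokens in document d assigned to topic
            }
            score[d] = (count + smoothing)
                       / (docTopicAssignments[d].length + numTopics * smoothing);
        }
        Integer[] order = new Integer[numDocs];
        for (int d = 0; d < numDocs; d++) order[d] = d;
        Arrays.sort(order, (a, b) -> Double.compare(score[b], score[a])); // highest score first

        int n = Math.min(max, numDocs);
        int[] top = new int[n];
        for (int i = 0; i < n; i++) top[i] = order[i];
        return top;
    }
}
```

With positive smoothing, short documents are pulled toward 1 / numTopics, so a one-token document no longer scores a perfect 1.0 for its only topic.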
-
printState
public void printState(java.io.File f) throws java.io.IOException
- Throws:
java.io.IOException
-
printState
public void printState(java.io.PrintStream out)
-
modelLogLikelihood
public double modelLogLikelihood()
-
getInferencer
public TopicInferencer getInferencer()
Return a tool for estimating topic distributions for new documents.
-
getProbEstimator
public MarginalProbEstimator getProbEstimator()
Return a tool for evaluating the marginal probability of new documents under this model.
-
write
public void write(java.io.File serializedModelFile)
-
read
public static ParallelTopicModel read(java.io.File f) throws java.lang.Exception
- Throws:
java.lang.Exception