Why I don't recommend stochastic variational Bayes for topic models

As we approach the 20th anniversary of LDA, it's clear that topic models have gained a place in the data science toolkit. But I often hear comments like "topic models don't work" or even suspicions that the models my colleagues and I show are "too good". What explains this lack of confidence? I have a suspicion.

There are a large number of topic modeling algorithms. Most of them are about equally good, but there's one --- stochastic variational Bayes --- that often produces noticeably worse output in the settings in which many people are using it. It also happens to be the default implementation in the popular and otherwise excellent gensim Python library. I'd like to explain why I think this happens.

What a topic model does

The objective that a topic model optimizes is to represent a collection of documents as combinations of $K$ topics. One way to describe this representation is to consider each individual word in each document and assign a proportional "responsibility" for that word to each of the topics. This token-level association is what gives topic models the ability to represent words in context: an instance of the word "play" in a document about a band might be 90% associated with a topic about music, while an instance of the same word in a document about a soccer game might be 90% associated with a topic about sports.

If you know these token-level associations, you can then derive representations for topics and for documents. For a topic, you can get a distribution over words by counting up all the word tokens that the topic is responsible for in all the documents. Similarly, for each document you can get a distribution over topics by adding up the assignment values for each topic for all the words in the document.
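Here's a minimal sketch of that aggregation step. Everything in it is made up for illustration: the toy corpus, the vocabulary size, and the random responsibilities stand in for whatever the current state of the model happens to be.

In [ ]:
import numpy as np

# Hypothetical toy corpus: 2 documents over a 4-word vocabulary, K = 2 topics.
# docs[d] lists the word ids of the tokens in document d.
docs = [np.array([0, 1, 1, 2]), np.array([2, 3, 3, 3])]
V, K = 4, 2

# resp[d][i, k]: responsibility of topic k for token i of document d (rows sum to 1).
rng = np.random.default_rng(0)
resp = [rng.dirichlet(np.ones(K), size=len(doc)) for doc in docs]

# Topic-word distributions: for each topic, add up the responsibilities of every
# token of each word type, then normalize over the vocabulary.
topic_word = np.zeros((K, V))
for doc, r in zip(docs, resp):
    for word_id, r_token in zip(doc, r):
        topic_word[:, word_id] += r_token
topic_word /= topic_word.sum(axis=1, keepdims=True)

# Document-topic distributions: add up responsibilities within each document,
# then normalize over topics.
doc_topic = np.array([r.sum(axis=0) for r in resp])
doc_topic /= doc_topic.sum(axis=1, keepdims=True)

print(topic_word.round(3))
print(doc_topic.round(3))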

You can also go the other way, and get token-level assignments from topic- and document-level distributions. For a word $w$ in a document $d$, you can multiply the probability of word $w$ given each topic $k$ by the probability of topic $k$ given document $d$. That gives you the joint probability of $w$ and each $k$, so renormalizing those values so they sum to 1.0 gives you the conditional probability of $k$ given $w$ and $d$. For example, the probability of "play" might be roughly equal in the sports and music topics, but if the document has higher probability of music, perhaps because of other more unambiguous words like "stage" or "guitar", then the word token will have greater association with music. Similarly, a document about player salary negotiations might have equal probability of topics about sports and business, but if the word "goalkeeper" only has high probability in sports, it will have greater association with that topic.
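To make that concrete, here's a tiny sketch with invented numbers: the two topics and the 80/20 document are made up to mirror the "play" example, where the word itself is ambiguous and the document context breaks the tie.

In [ ]:
import numpy as np

# Hypothetical global distributions for K = 2 topics over a 4-word vocabulary.
# Word 1 plays the role of an ambiguous word like "play": equal probability
# under both topics.
topic_word = np.array([[0.40, 0.25, 0.25, 0.10],   # a "music"-like topic
                       [0.10, 0.25, 0.25, 0.40]])  # a "sports"-like topic
doc_topic = np.array([0.8, 0.2])  # this document leans heavily toward topic 0

word_id = 1
joint = topic_word[:, word_id] * doc_topic  # p(w | k) * p(k | d) for each topic k
assignment = joint / joint.sum()            # renormalize so the values sum to 1.0
print(assignment)                           # -> [0.8, 0.2]: the context breaks the tie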

If we combine these two steps (local assignments to global distributions, global distributions to local assignments) and repeat, we get an EM algorithm. This process gives us the almost mystical ability of topic models to go from random initialization to meaningful, interpretable topics. As long as we start with an initial set of variables that is not perfectly symmetrical (if it were, both update directions would simply yield the same distribution for every topic, forever), small differences in probability will slowly reinforce themselves and encourage documents that share similar words to have similar topic distributions, and words that tend to occur together to have high probability in the same topics. Qualitatively, EM algorithms tend to go through three phases: a period of slow improvement at the beginning when values are mostly random, a period of fast improvement where the algorithm finds a good gradient and makes larger steps, and a period of slow convergence.
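Here is a deliberately tiny, non-Bayesian sketch of that loop, just to show how the two directions alternate. The toy corpus and the function are mine, not taken from any library, and there is no smoothing yet.

In [ ]:
import numpy as np

def toy_em(docs, V, K, iterations=50, seed=0):
    """A minimal EM sketch for a toy topic model: alternate local
    token assignments and global re-estimation."""
    rng = np.random.default_rng(seed)
    # Break symmetry with a slightly noisy initialization.
    topic_word = rng.dirichlet(np.ones(V), size=K)
    doc_topic = rng.dirichlet(np.ones(K), size=len(docs))

    for _ in range(iterations):
        new_topic_word = np.zeros((K, V))
        new_doc_topic = np.zeros((len(docs), K))
        for d, doc in enumerate(docs):
            for word_id in doc:
                # Local step: responsibility proportional to p(w|k) * p(k|d).
                joint = topic_word[:, word_id] * doc_topic[d]
                resp = joint / joint.sum()
                # Accumulate evidence for the global step.
                new_topic_word[:, word_id] += resp
                new_doc_topic[d] += resp
        # Global step: renormalize the accumulated responsibilities.
        topic_word = new_topic_word / new_topic_word.sum(axis=1, keepdims=True)
        doc_topic = new_doc_topic / new_doc_topic.sum(axis=1, keepdims=True)
    return topic_word, doc_topic

docs = [np.array([0, 1, 1, 2]), np.array([2, 3, 3, 3]), np.array([0, 0, 1])]
topic_word, doc_topic = toy_em(docs, V=4, K=2)
print(topic_word.round(3))
print(doc_topic.round(3))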

An important observation about this algorithm is that zeros are forever. If either the probability of a topic in a document or the probability of a word in a topic go to zero, they can never again be non-zero. Because the local assignment of words to topics involves the product of these two values, if either of them is zero, the result will be zero. And since the global values are estimated from the sum of the local values, they will always be zero as well. This property can be beneficial, if we can use it to avoid doing computations that we know will be irrelevant, like calculating the probability of the cooking topic in a document about video games. But if we zero out certain possibilities too quickly, we may lock in a sub-optimal model before we have a chance to settle in on a better one. Bayesian topic models add a prior distribution, which means that the global distributions are the sums of proportional word allocations plus a small constant that keeps the probabilities away from zero even when there is currently no practical evidence that a word occurs in a topic or a topic occurs in a document.
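A small numeric illustration of what the prior buys us, with made-up counts and the symmetric prior value used as an example:

In [ ]:
import numpy as np

# Accumulated responsibilities for one topic over a 5-word vocabulary.
# Word 4 currently has no evidence at all in this topic.
counts = np.array([12.0, 7.5, 3.1, 0.4, 0.0])

# Maximum-likelihood estimate: word 4 gets exactly zero, and stays there forever.
ml_estimate = counts / counts.sum()

# Smoothed estimate with a hypothetical symmetric prior beta = 0.01: every word
# keeps a small but nonzero probability, so it can recover if evidence appears.
beta = 0.01
smoothed = (counts + beta) / (counts.sum() + beta * len(counts))

print(ml_estimate)
print(smoothed)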

Algorithms for topic models

You can derive a large number of popular algorithms with small variations of this process. Gibbs sampling computes the posterior distribution over topics for a word in a document under the Bayesian model, but then samples a single topic from that distribution. This sampling process is equivalent to setting the assignment 100% to one topic. The algorithm also updates the global variables immediately after each token assignment, rather than waiting for a full sweep. Iterated conditional modes does the same, but always chooses the most probable topic rather than sampling. Alternatively, if we treat the occurrence of a word in a document as simply an edge between two nodes in a network, and ignore any distinction between nodes that correspond to words and nodes that correspond to documents, the Ball, Karrer, and Newman Poisson community detection algorithm is identical to EM except for a different topic normalization step.
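Given the same conditional distribution for one token, the three local update rules differ only in what they write back. This is a toy illustration with an invented distribution, not an implementation of any of these algorithms:

In [ ]:
import numpy as np

rng = np.random.default_rng(0)

# Conditional distribution over K = 3 topics for one token.
conditional = np.array([0.6, 0.3, 0.1])

# EM / variational-style update: keep the full "soft" responsibility.
soft_assignment = conditional

# Gibbs sampling: draw one topic and assign 100% responsibility to it.
sampled_topic = rng.choice(len(conditional), p=conditional)
gibbs_assignment = np.eye(len(conditional))[sampled_topic]

# Iterated conditional modes: always take the most probable topic.
icm_assignment = np.eye(len(conditional))[conditional.argmax()]

print(soft_assignment, gibbs_assignment, icm_assignment)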

Gibbs sampling results in samples that can be averaged to approximate the intractable posterior distribution over topics. The variational Bayesian EM algorithm takes a different approach, and generates a set of tractable distributions whose product is as close as possible to that posterior distribution. In practice, this means that instead of setting the token assignments proportional to the product of the topic-word probability and the document-topic probability, we set them proportional to the exponential of the expectation of the log of those probabilities. The expectation of the log probabilities involves a special function known as digamma, the derivative of the log gamma function, where the gamma function is a continuous version of the factorial function.
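In code, the difference between the two update rules looks roughly like this. The counts here are invented, and this is a sketch of the standard LDA-style variational update, not any particular library's implementation:

In [ ]:
import numpy as np
from scipy.special import digamma

# Hypothetical smoothed counts (count plus prior) for one word across K = 3 topics,
# for the three topics within one document, and for the topic totals.
word_topic_counts = np.array([4.01, 0.31, 0.01])
doc_topic_counts = np.array([2.11, 1.51, 0.11])
topic_totals = np.array([120.0, 95.0, 80.0])

# Plain EM-style update: product of the topic-word and document-topic probabilities.
em_update = (word_topic_counts / topic_totals) * doc_topic_counts
em_update /= em_update.sum()

# Variational update: exponential of the expectation of the log probabilities,
# which is where the digamma function comes in.
vb_update = np.exp(digamma(word_topic_counts) - digamma(topic_totals)
                   + digamma(doc_topic_counts))
vb_update /= vb_update.sum()

print(em_update.round(4))
print(vb_update.round(4))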

Variational instability and digamma

I'm going to guess that all that was pretty intimidating. So here's the strange thing: while digamma itself is a complicated function that cannot be computed in closed form, the exponential of the digamma is very well approximated by... subtracting 0.5. The value for $x=2$ is 1.526. For 10 it's 9.504. For 100, 99.5004. Exponentials can't be negative, so the only place where this approximation doesn't work is when we get close to $x=\frac{1}{2}$. The function curves and flattens as $x$ gets close to zero. The value at $x=0.1$ is 0.00003.

In [ ]:
import numpy as np
from matplotlib import pyplot
from scipy.special import digamma
In [ ]:
# Compare exp(digamma(x)) with the simple approximation x - 0.5:
# zoomed out over [0, 5], and zoomed in on [0, 1] where the two diverge.
figure, (zoomed_out, zoomed_in) = pyplot.subplots(1, 2, figsize=(15,4))

x = np.linspace(0, 5, 30)
digamma_x = np.exp(digamma(x))
x_minus_half = x - 0.5

zoomed_out.hlines(y=0, xmin=0, xmax=5, colors=["lightgray"])
zoomed_out.plot(x, digamma_x, label="exp(digamma(x))")
zoomed_out.plot(x, x_minus_half, label="x - 0.5")
zoomed_out.legend()

x = np.linspace(0, 1, 30)
digamma_x = np.exp(digamma(x))
x_minus_half = x - 0.5

zoomed_in.hlines(y=0, xmin=0, xmax=1, colors=["lightgray"])
zoomed_in.plot(x, digamma_x, label="exp(digamma(x))")
zoomed_in.plot(x, x_minus_half, label="x - 0.5")
zoomed_in.legend()

figure.suptitle("exp digamma vs subtracting one half")
pyplot.show()

So what does this do to our algorithm? Instead of adding up all the token assignments for a given word and a given topic and dividing by the sum over all word assignments for the topic, we are effectively adding up the token assignments for that word and topic, subtracting one half, and dividing by the sum over all word assignments for the topic, itself also reduced. If all of the numbers in the numerator are far enough from zero that the "subtract 0.5" approximation works, then this process has exactly the opposite effect of Bayesian smoothing, where we add a small constant to all the values and renormalize. Adding a constant gives us a distribution closer to uniform, more "spread out". Subtracting a constant from every dimension gives us a less uniform, "spikier" distribution.
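A quick numeric check of that intuition, with made-up counts for one topic:

In [ ]:
import numpy as np

# Toy accumulated counts for one topic over a 4-word vocabulary.
counts = np.array([10.0, 5.0, 2.0, 1.0])

def normalize(v):
    return v / v.sum()

print(normalize(counts))        # the baseline distribution
print(normalize(counts + 1.0))  # adding a constant: flatter, more uniform
print(normalize(counts - 0.5))  # subtracting one half: spikier, less uniform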

Things get a little more dangerous for values that are closer to zero. The exponentiated digamma essentially zeroes out any values that are less than about 0.1, and significantly reduces other values less than one half. The variational approximation is causing exactly the problem that we were trying to avoid by adding Bayesian priors. One indication of this is that the hyperparameters used in variational implementations tend to be an order of magnitude larger than those used for Gibbs samplers. I set the default topic-word smoothing for Mallet at 0.01, while most VB papers use 0.1. Increasing the smoothing parameter pushes these values away from "basically zero" (the exponentiated digamma of 0.01 is about $10^{-44}$) to something that is small but noticeably not zero.

In [ ]:
# A small example distribution over 5 items; most of its values are below one half.
distribution = np.array([0.1, 0.1, 0.2, 0.5, 0.1])

figure, (before, after) = pyplot.subplots(1, 2, figsize=(15,4))
before.bar(x=np.linspace(1,5,5), height=distribution)
before.set_title("A distribution before exp digamma")
after.bar(x=np.linspace(1,5,5), height=np.exp(digamma(distribution)))
after.set_title("The same distribution after exp digamma")
pyplot.show()

This "spikiness" of the variation distributions provides a perspective on the objective of variational inference. There are many ways to formulate the VB objective, but one of them is to minimize the KL divergence from the variational distribution to the true posterior distribution. The direction from matters. KL divergence gets very large if there is an event that the variational distribution thinks is possible, but the true posterior does not. It's like if you were talking about music to someone you wanted to impress, and you would be mortified if you said you liked a certain band, which it turned out they hated. The variation distribution therefore tends to be cautious, avoiding putting weight anywhere that it's not confident will also be supported by the true posterior.

Batch VB works because the initialization is close enough to uniform: even if the numbers are small, proportionally no one value dominates the others. During that initial "looking around" phase of EM, the algorithm never makes really big moves in the parameters, so things stay balanced until they have a good reason not to be balanced.

Stochastic inference breaks this assumption. Like Gibbs sampling, stochastic inference differs from standard EM in the schedule by which certain parameters are updated. Rather than generating an update based on all the documents, stochastic inference generates an approximate update based on a small subset of documents. There's nothing wrong with that in principle --- stochastic optimization is arguably the most important and useful algorithm of the past 10 years. But it interacts poorly with the VB objective. Stochastic optimization works when a large number of small, low-quality jumps in parameters add up to a large, high-quality shift. If, however, a seemingly small change in parameters is filtered through an exponentiated digamma function, it can "lock in" a noisy, unreliable parameter update.
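To make that interaction concrete, here is a deliberately simplified caricature of a stochastic-style global update. The corpus size, batch size, step size, and counts are all made up, and this is not gensim's exact update rule: the point is just that scaling up a small minibatch and then passing the result through the exponentiated digamma can crush every word that happened to miss that minibatch.

In [ ]:
import numpy as np
from scipy.special import digamma

n_documents = 100_000   # hypothetical corpus size
batch_size = 64         # documents per minibatch
rho = 0.8               # step size early in training, while it is still large
beta = 0.01             # topic-word smoothing

current_counts = np.full(5, 0.5)  # current smoothed counts for 5 words in one topic
batch_counts = np.array([3.0, 0.0, 0.0, 0.0, 0.0])  # noisy minibatch: only word 0 seen

# Scale the minibatch up as if the whole corpus looked like it, then interpolate.
noisy_estimate = beta + batch_counts * (n_documents / batch_size)
new_counts = (1 - rho) * current_counts + rho * noisy_estimate
print(new_counts.round(3))

# The variational E step then passes these counts through the exponentiated
# digamma: the words that missed this one minibatch are pushed toward zero.
print(np.exp(digamma(new_counts)).round(6))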

Getting great topics

I don't want to suggest that stochastic VB is bad. Stochastic optimization is the foundation of modern ML. Variational methods have been successful recently in the context of autoencoders with Gaussian variational distributions over continuous parameters. And I love the exponential family natural gradient mathematics of stochastic VB. But in practice, it's always been intended for use on massive collections of millions of documents, where there is so much data that instabilities can wash out over time, and other alternatives might not be feasible.

For the 3,000-document corpora that are the workhorse of day-to-day topic modeling, it just isn't appropriate. The speed of SVB is mostly an illusion --- on a per-sweep basis it's actually not that fast, because of the digamma functions. It appears fast because it's just not doing many sweeps over the data: the default setting for gensim is a single pass. I've also heard that restricting vocabulary sizes to a few thousand words makes VB work better, probably because rarer words are more likely to have the low counts that fall into the digamma danger zone and get locked in. But why reduce the expressivity of the model to accommodate a non-optimal algorithm? Sparse sampling methods have no trouble with million-word vocabularies.

Ultimately, getting good topic model results isn't hard, but it isn't free. When you have a collection of up to hundreds of thousands of documents, take a little more time (minutes vs. seconds), use a more stable algorithm, and get results that are actually useful.