Variational Inference

Review of EM algorithm

In the last section, we discussed how the EM algorithm works and why it works in general. The idea is to find a suitable latent variable that can be integrated out to recover the distribution we started with. The latent variable is introduced to ease the mathematical difficulties that arise when we try to find the optimal parameters. With the latent variable, we can find a point estimate of the parameters that gives the highest joint likelihood.

However, EM does not always work. When it does not, VI can step in. We will see some examples where EM does not work and build the transition from EM to VI.

Model Setup

In this section, we will look at several model setups and see at which stage EM fails.

Model V1

This version is the simplest model setup. Formally, we have:

We assume the samples are i.i.d. and the dimensions match. There is no prior distribution on x. Thus, we are in the discriminative setting rather than the generative setting, in which we would have a prior model on x.
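Assuming the standard Bayesian linear regression setup that the rest of this section relies on, with $\alpha$ the noise precision and $\lambda$ the prior precision on w, Model V1 can be written as:

$$
y_i \sim \text{Normal}(x_i^T w,\ \alpha^{-1}),\quad i=1,\dots,N, \qquad w \sim \text{Normal}(0,\ \lambda^{-1} I)
$$

The symbol $N$ for the number of samples is introduced here for illustration.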

We can calculate the posterior on w using Bayes rule:

This does not require any iterative algorithm and is very easy to calculate in closed form. However, we notice that there are two parameters that we need to set in advance: $\alpha,\lambda$. These parameters can affect the final performance. In general, there are two ways to handle them:

1 We can use cross-validation on these two parameters and select the values with the best performance.

2 In the Bayesian approach, we can put a prior distribution on the parameters and learn them!
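As a minimal sketch of this closed-form calculation, assume the usual form of the model, $y_i \sim \text{Normal}(x_i w, \alpha^{-1})$ with prior $w \sim \text{Normal}(0, \lambda^{-1})$, and a scalar w for brevity. The posterior is then Gaussian with precision $\lambda + \alpha\sum_i x_i^2$ and mean $\alpha\sum_i x_i y_i$ divided by that precision. The function name `posterior_w` and the toy data are illustrative only.

```python
def posterior_w(x, y, alpha, lam):
    """Closed-form Gaussian posterior of a scalar weight w.

    Assumed model: y_i ~ Normal(x_i * w, 1/alpha), w ~ Normal(0, 1/lam).
    Returns (mean, variance) of p(w | y, x).
    """
    precision = lam + alpha * sum(xi * xi for xi in x)
    variance = 1.0 / precision
    mean = variance * alpha * sum(xi * yi for xi, yi in zip(x, y))
    return mean, variance


# Toy data generated by y = 2 * x with no noise.
x = [1.0, 2.0, 3.0]
y = [2.0, 4.0, 6.0]
mean, variance = posterior_w(x, y, alpha=1.0, lam=1.0)
```

With these toy numbers the posterior mean is slightly below 2 because the zero-mean prior shrinks the estimate. Both `alpha` and `lam` have to be supplied by hand, which is exactly the limitation discussed here.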

Model V2

So we pick a conjugate gamma prior on $\alpha$. Then, we have:

In this case, $\alpha$ becomes another model variable that we want to learn, like w. We can do one of the following:
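Assuming the shape-rate parameterization, with hyperparameters named $a$ and $b$ to match the $a^{\prime}, b^{\prime}$ used for $q(\alpha)$ later in these notes, the prior reads:

$$
\alpha \sim \text{Gamma}(a, b), \qquad p(\alpha) = \frac{b^{a}}{\Gamma(a)}\,\alpha^{a-1}e^{-b\alpha}
$$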

1 Try to learn the full posterior on $\alpha$ and w (We cannot assume w and $\alpha$ are independent).

2 Learn a point estimate of w and $\alpha$ based on MAP inference.

3 Integrate out some variables to learn the others.

Let’s look at each one:

1 The full posterior of $\alpha$ and w is:

The normalizing constant cannot be calculated in closed form. We can approximate the posterior with other methods, such as the Laplace approximation or MCMC sampling.

2 We can do MAP over w and $\alpha$:

We can use a coordinate descent algorithm to optimize one parameter at a time. However, this tells us nothing about uncertainty. It only gives a point estimate of the model parameters.

3 We can do marginal likelihood as:

With this marginal likelihood, we can:

(1) Try posterior inference for $p(w\lvert y,x)\propto p(y\lvert w,x)p(w)$. However, this does not work, since $p(y\lvert w,x)$ results from integrating out $\alpha$ and is no longer Gaussian. Rather, it is a Student-t distribution, which is not conjugate to the prior on w.

(2) Maximize $p(y,w\lvert x)$ over w, that is, do MAP inference with a gradient method. This is where EM could possibly come in.

EM for Model V2

We treat $\alpha$ as the latent variable and look for a point estimate of w that maximizes the marginal distribution $p(y,w\lvert x)=\int p(y,w,\alpha\lvert x)\,d\alpha$.

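Based on the above discussion, we can write the EM master equation:

$$
\ln p(y,w\lvert x) = \int q(\alpha)\ln\frac{p(y,w,\alpha\lvert x)}{q(\alpha)}\,d\alpha + \int q(\alpha)\ln\frac{q(\alpha)}{p(\alpha\lvert y,w,x)}\,d\alpha
$$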


Now, we set $q_t(\alpha) = p(\alpha\lvert y,w_{t-1},x)$, and then calculate the objective function:


Then, we have to make sure that we can maximize the objective function in closed form. Otherwise, there is no point in doing this. We can try and find (I will give the math later):

We can plug $\mathbb{E}_{q_t}[\alpha]$ into the above to get the update for w.
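The two steps can be sketched in code. This is a minimal illustration, not the derivation promised above: it assumes a scalar weight w, the model $y_i \sim \text{Normal}(x_i w, \alpha^{-1})$, $w \sim \text{Normal}(0, \lambda^{-1})$, $\alpha \sim \text{Gamma}(a, b)$ in shape-rate form, and the standard conjugate results that $q_t(\alpha)$ is a gamma distribution and that the w update is a ridge-style solve with $\mathbb{E}_{q_t}[\alpha]$ in place of $\alpha$. The name `em_model_v2` and the synthetic data are hypothetical.

```python
import random


def em_model_v2(x, y, lam, a, b, n_iter=50):
    """EM for a scalar weight w with the noise precision alpha as latent variable.

    Assumed model: y_i ~ Normal(x_i * w, 1/alpha), w ~ Normal(0, 1/lam),
    alpha ~ Gamma(a, b) with b a rate parameter.
    Returns the point estimate of w and the Gamma parameters of q(alpha).
    """
    n = len(x)
    sxx = sum(xi * xi for xi in x)
    sxy = sum(xi * yi for xi, yi in zip(x, y))
    w = 0.0
    for _ in range(n_iter):
        # E-step: q(alpha) = p(alpha | y, w, x) is Gamma(a_post, b_post).
        sse = sum((yi - xi * w) ** 2 for xi, yi in zip(x, y))
        a_post = a + 0.5 * n
        b_post = b + 0.5 * sse
        e_alpha = a_post / b_post
        # M-step: maximise E_q[ln p(y, w, alpha | x)] over w.
        w = e_alpha * sxy / (lam + e_alpha * sxx)
    return w, a_post, b_post


# Synthetic data: true weight 2.0, noise standard deviation 0.5 (precision 4).
random.seed(0)
x = [random.uniform(-1.0, 1.0) for _ in range(500)]
y = [2.0 * xi + random.gauss(0.0, 0.5) for xi in x]
w_hat, a_post, b_post = em_model_v2(x, y, lam=1.0, a=1.0, b=1.0)
```

The output is a single number for w together with a full gamma distribution for $\alpha$, whose mean `a_post / b_post` should land near the true noise precision on this synthetic data.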

Note that a few things differ from how we used EM last time.

(1) In the previous case, we introduced a latent variable to reduce the mathematical complexity. In this case, however, the latent variable $\alpha$ has an interpretation: it is related to the observation noise. So this latent variable has a meaning of its own.

(2) In this case, we have w and $\alpha$ to learn. However, we make a compromise by learning a point estimate of w and a conditional posterior of $\alpha$. It is hard for us to learn point estimates of two variables. Surely, you could have done the reverse. That is, you can learn a point estimate of $\alpha$ and a conditional posterior of w.

Model V3

The EM algorithm works for model version 2. Let’s look at a new version of the model, namely version 3. In this setup, we have one more variable to work with.
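Assuming the extra variable is the prior precision $\lambda$, with a gamma prior whose hyperparameters are named $e$ and $f$ to match the $e^{\prime}, f^{\prime}$ used for $q(\lambda)$ later in these notes, the full model can be sketched as:

$$
y_i \sim \text{Normal}(x_i^T w,\ \alpha^{-1}), \qquad w \sim \text{Normal}(0,\ \lambda^{-1} I), \qquad \alpha \sim \text{Gamma}(a, b), \qquad \lambda \sim \text{Gamma}(e, f)
$$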

Then, the marginal distribution is:

EM for Model V3

Note that we have two latent variables now instead of one in V2. The EM master equation is:

Note that $\ln p(y,w\lvert x)$ in this model setup is different from the one in V2, although the form is the same. That is, the underlying distribution is essentially different.


As usual, we have:

We can calculate the conditional posterior of both parameters:


We have the same form for the updating as:

The only difference is that we have the expectation for $\lambda$ as well. This worked out since we can factorize the full conditional posteriors of the two variables as:

This is not always true. In this case, the model happens to be factorizable. Eventually, we have a point estimate of one variable and full conditional posteriors of the others.
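Under the assumed model ($w \sim \text{Normal}(0, \lambda^{-1} I)$ with $w \in \mathbb{R}^d$, gamma priors $\text{Gamma}(a,b)$ on $\alpha$ and $\text{Gamma}(e,f)$ on $\lambda$ in shape-rate form), the factorization and the resulting update can be sketched as follows. The symbols $N$ (number of samples) and $d$ (dimension of w) are introduced here for illustration.

$$
p(\alpha,\lambda\lvert y,w,x) = p(\alpha\lvert y,w,x)\,p(\lambda\lvert w)
$$

$$
p(\alpha\lvert y,w,x) = \text{Gamma}\Big(a+\frac{N}{2},\ b+\frac{1}{2}\sum_{i=1}^N (y_i - x_i^T w)^2\Big), \qquad p(\lambda\lvert w) = \text{Gamma}\Big(e+\frac{d}{2},\ f+\frac{1}{2}w^T w\Big)
$$

$$
w_t = \Big(\mathbb{E}_{q_t}[\lambda]\,I + \mathbb{E}_{q_t}[\alpha]\sum_{i=1}^N x_i x_i^T\Big)^{-1}\Big(\mathbb{E}_{q_t}[\alpha]\sum_{i=1}^N y_i x_i\Big)
$$

The factorization holds because, once w is given, $\alpha$ only appears in the likelihood term and $\lambda$ only appears in the prior on w.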

The question is: what if we cannot factorize or even find the posterior of the variables? Then we have a problem. If we cannot calculate or factorize the posterior, we cannot proceed with the EM algorithm. This is where VI might come in.

From EM to VI

In EM, we learn posteriors of all variables except the one that is learned as a point estimate. So let’s try to learn posteriors of all variables. For model V3, we can have:

Note that different from EM, there is nothing to optimize on the left hand side. We can further write:

Again, there is nothing to optimize on the left hand side. In addition, we cannot complete our E-step since we cannot calculate the full posterior:

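So let’s look at the VI equation:

$$
\ln p(y\lvert x) = \int q(w,\alpha,\lambda)\ln\frac{p(y,w,\alpha,\lambda\lvert x)}{q(w,\alpha,\lambda)}\,dw\,d\alpha\,d\lambda + \int q(w,\alpha,\lambda)\ln\frac{q(w,\alpha,\lambda)}{p(w,\alpha,\lambda\lvert y,x)}\,dw\,d\alpha\,d\lambda
$$

In EM, we want a suitable q distribution in order to find a point estimate of w, and both the left- and right-hand sides of the EM equation are optimized. In VI, there is nothing to optimize on the left-hand side, so we view this equation differently. In particular, we are more interested in the q distributions and how to choose them.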

Let’s look at each term in the VI equation individually.

(1) $\ln p(y\lvert x)$: This is the marginal likelihood of the data given the model. We cannot compute it. If we could, we could also find the full posterior:

So we cannot explicitly calculate $p(y\lvert x)$. However, given data and model, we know that this value is always constant no matter how we change the model variables.

(2) $KL(q\lvert\rvert p)$: Again, we do not know how to calculate this KL divergence, because it involves the unknown posterior p. However, we do know that it is non-negative and equals zero iff q = p.

(3) Objective function: For this term, we pick a q for which we can evaluate the integral and therefore calculate the term.

The goal of VI is to pick a good q that can be as close as possible to the posterior distribution $p(w,\alpha,\lambda\lvert y,x)$. There are some other ways to do it. For example, we can use Laplace approximation. However, in this case, we are more interested in inference algorithm.

To measure closeness, we keep using the KL divergence; in other words, we want to minimize the KL term. We cannot calculate the KL divergence directly. However, we know that the left-hand side of the VI equation is constant and that the KL divergence is non-negative. Thus, minimizing the KL divergence is equivalent to maximizing the VI objective function, which is the middle term in the VI equation. This also shows that in VI we care about the q distribution itself, not about the latent variable as in EM.
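The relationship between the three terms can be checked numerically. The sketch below uses a hypothetical discrete toy model (one latent variable z with three states and made-up probabilities) rather than the regression model, and verifies that the log marginal likelihood equals the variational objective plus $KL(q\lvert\rvert p)$ for an arbitrary q.

```python
import math

# Hypothetical toy model: one observed x and a latent z with three states.
# joint[k] = p(x, z = k) for the observed x, so the entries sum to p(x).
joint = [0.10, 0.25, 0.05]
evidence = sum(joint)                       # p(x)
posterior = [j / evidence for j in joint]   # p(z | x)

# An arbitrary approximating distribution q(z).
q = [0.5, 0.3, 0.2]

# Variational objective: E_q[ln p(x, z) - ln q(z)].
elbo = sum(qk * (math.log(jk) - math.log(qk)) for qk, jk in zip(q, joint))
# KL(q || p(z | x)) = E_q[ln q(z) - ln p(z | x)].
kl = sum(qk * (math.log(qk) - math.log(pk)) for qk, pk in zip(q, posterior))

# Zero up to floating-point error: ln p(x) = objective + KL.
gap = math.log(evidence) - (elbo + kl)
```

Because `math.log(evidence)` does not depend on `q`, any change to `q` that raises `elbo` must lower `kl` by the same amount, which is the argument made above.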

VI for Model V3

Let’s step back to the model setup version 3 to see how VI works there. In this case, we need to:

(1) Define a family of q distributions over $\alpha, w, \lambda$. The parameters of q are the quantities to be optimized. Note that in EM, we find q such that the marginal likelihood is correct.

(2) construct variational objective function:

where $\theta$ denotes all the parameters that are to be optimized. In this case, $w,\alpha,\lambda$ do not appear in the final expression since they are integrated out.

In EM, the story is different: the objective function $\mathcal{L}$ is a function of a model variable, and we can do MAP or ML inference on it depending on whether we have a prior on that variable. Also, the parameters of the q distribution on the latent variable in EM are always known, since we set them to specific values. In VI, on the other hand, the parameters of the q distributions are unknown; they are the arguments of the objective function that we are trying to optimize. The key point is that although EM and VI have objective functions of a similar form, they are used in completely different ways.

Now, we need to define q distributions. In general, we use something called “mean-field” assumption. In this approach, we can have:

Then, we can handle them one at a time. Note that if w is a vector, we can also write $q(w)=\prod_j q(w_j)$.

We will talk later about how to choose the optimal family for each of them. The algorithm should still work no matter which distributions we pick. In this case, we can take them from the same families as their priors.
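Written out, and using the parameter names $a^{\prime},b^{\prime},e^{\prime},f^{\prime},\mu^{\prime},\Sigma^{\prime}$ that these notes optimize later, this choice can be sketched as:

$$
q(w,\alpha,\lambda) = q(w)\,q(\alpha)\,q(\lambda)
$$

$$
q(\alpha) = \text{Gamma}(a^{\prime}, b^{\prime}), \qquad q(\lambda) = \text{Gamma}(e^{\prime}, f^{\prime}), \qquad q(w) = \text{Normal}(\mu^{\prime}, \Sigma^{\prime})
$$

Matching each factor to the family of its prior is an assumption of convenience at this stage, not a claim that these families are optimal.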

We then optimize their parameters so that q is as close as possible to the full posterior, by maximizing the variational objective function.

The reason for the factorization is that it is hard to find a good distribution family for multiple parameters that have different supports. For example, $\alpha$ is only valid on the positive real numbers while w is valid on the whole real line. How do we define a distribution that takes both as support? So it is better to consider them one by one.

Recall that in EM for model version 3, $\lambda$ and $\alpha$ were conditionally independent given w, which let us write $q(\alpha,\lambda) = q(\alpha)q(\lambda)$ exactly. This is no longer the case under the mean-field assumption, because the full posterior over all variables does not factorize. That means:

It indicates that KL is always larger than zero in VI.

Then, we should calculate the objective function:

You can see how complicated this is. Note that the independence assumption has helped us a lot here, since each expectation is taken only w.r.t. the variable of interest. So we can move the expectations inwards to make the calculation easier.

The next step is to take derivatives w.r.t. $a^{\prime},b^{\prime},e^{\prime},f^{\prime}, \mu^{\prime},\Sigma^{\prime}$ and use a gradient method to optimize them one at a time.

Variational Inference

We will talk about variational inference formally. Let’s define the model variables to be $\theta_1,\theta_2,\dots,\theta_m$ and parameters for those variables to be $\psi_1,\psi_2,\dots,\psi_m$.

Mean-field Assumption

This is an important assumption that makes variational inference work. Under the mean-field assumption, we split $\theta$ into m different groups. Note that each component can live in a different space. Each q distribution can be written as $q(\theta_i\lvert\psi_i)$, so the $\psi_i$ are the parameters that we need to learn from the variational objective function.

In mean-field assumption, we assume that:

can approximate the posterior:

Variational objective function

With this mean-field definition, we can write variational objective function as:

We are trying to optimize all the $\psi_i$. There are two ways to make variational inference work: the direct method and the optimal method. We will see how the direct method works and then appreciate the effectiveness of the optimal method. Before that, we need to figure out which distribution family to pick for each q distribution.

Picking $q(\theta_i\lvert\psi_i)$

Essentially, we are facing two key problems here:

1 How do we choose each $q_i = q(\theta_i\lvert\psi_i)$?

2 How do we get the best $\psi_i$ for each i?

Setup: we have a joint likelihood $p(X,\theta_1,\dots,\theta_m)$, and the mean-field assumption says that the posterior satisfies $p(\theta_1,\dots,\theta_m\lvert X) \approx \prod_{i=1}^m q(\theta_i\lvert\psi_i)$.

Let’s pick an arbitrary $i\in [1,\dots,m]$ and the corresponding $q(\theta_i\lvert\psi_i)$. We expand our variational objective function as:

Since we only care about $\theta_i$ and $\psi_i$, we can treat the variational objective as a function of these two quantities. Then, we can further write:

Math: From the first line to the second line, we first merge the two integrals. Then, we factor out $q(\theta_i\lvert\psi_i)$. Last but not least, in order to use the properties of the ln function, we exponentiate and then take the ln of the expected log joint likelihood.

Let’s look at this variational objective function for a second. We want to pick a $q(\theta_i\lvert\psi_i)$ that maximizes it. What we have done so far is isolate $q_i$ and rewrite the variational objective function in terms of $q_i$. Given that we only care about $q_i$, the question becomes how to maximize the expression above.

If we look at this form of the variational objective function, we can see that it looks somewhat like a KL divergence, except that the numerator is not a distribution and the order of numerator and denominator is reversed. For the first issue, we can turn the numerator into a distribution by normalizing it. For the second issue, we can put a negative sign in front of the ln. By doing these two things, we obtain a KL divergence.

In particular, let’s define:

We can add and subtract $\ln Z$ in the variational objective:


is a distribution of $\theta_i$. We can further write the objective as:

When can we get the max value of the above variational objective function? The answer is when those two distributions are the same. Thus, for a particular i, we should set:

This gives us the optimal distribution family to use, whatever the parameters are. Since i is arbitrary, this works for any i. In addition, once we can write out the full distribution for each i, we can also read off the optimal parameters for each i. This solves our second question as well.
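The steps leading to this update can be written out as a sketch. The notation $\mathbb{E}_{-q_i}$, meaning the expectation over every $q_j$ with $j\neq i$, is introduced here for compactness and is an assumption about notation rather than something fixed by the text above.

$$
\mathcal{L} = \int q(\theta_i\lvert\psi_i)\,\ln\frac{\exp\{\mathbb{E}_{-q_i}[\ln p(X,\theta_1,\dots,\theta_m)]\}}{q(\theta_i\lvert\psi_i)}\,d\theta_i + \text{const w.r.t. } \theta_i
$$

$$
Z = \int \exp\{\mathbb{E}_{-q_i}[\ln p(X,\theta_1,\dots,\theta_m)]\}\,d\theta_i
$$

$$
\mathcal{L} = -KL\Big(q(\theta_i\lvert\psi_i)\ \Big\lvert\Big\rvert\ \frac{1}{Z}\exp\{\mathbb{E}_{-q_i}[\ln p(X,\theta_1,\dots,\theta_m)]\}\Big) + \ln Z + \text{const w.r.t. } \theta_i
$$

$$
q(\theta_i\lvert\psi_i) = \frac{1}{Z}\exp\{\mathbb{E}_{-q_i}[\ln p(X,\theta_1,\dots,\theta_m)]\}
$$

Since $Z$ does not depend on $q(\theta_i\lvert\psi_i)$, the objective is largest when the KL term is zero, which gives the last line.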

On a high level, the general variational inference algorithm can be described as:

For iteration t:

1 For model variable index $i=1,\dots,m$, we set:

2 Evaluate the variational objective function using the updated q:

3 If the increase from $\mathcal{L}_{t-1}$ to $\mathcal{L}_t$ is below some threshold value, we terminate the process.
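The loop above can be illustrated on a problem small enough to compute everything exactly. The sketch below is a hypothetical example, not the regression model: two discrete latent variables $z_1, z_2$ with a made-up joint table for a fixed observation, and a mean-field $q(z_1)q(z_2)$. Each factor is set proportional to the exponentiated expected log joint, and the objective is recorded after every sweep.

```python
import math


def normalise(v):
    s = sum(v)
    return [vi / s for vi in v]


def elbo(joint, q1, q2):
    """E_q[ln p(X, z1, z2)] - E_q[ln q(z1)] - E_q[ln q(z2)] for a joint table."""
    value = 0.0
    for i, q1i in enumerate(q1):
        for j, q2j in enumerate(q2):
            value += q1i * q2j * math.log(joint[i][j])
    value -= sum(p * math.log(p) for p in q1)
    value -= sum(p * math.log(p) for p in q2)
    return value


def cavi(joint, n_iter=100, tol=1e-10):
    """Mean-field coordinate updates: q_i proportional to exp(E_{-q_i}[ln p])."""
    n1, n2 = len(joint), len(joint[0])
    q1 = [1.0 / n1] * n1
    q2 = normalise([1.0 + j for j in range(n2)])  # arbitrary non-uniform start
    history = [elbo(joint, q1, q2)]
    for _ in range(n_iter):
        # Update q(z1): average ln p over q(z2), exponentiate, normalise.
        q1 = normalise([
            math.exp(sum(q2[j] * math.log(joint[i][j]) for j in range(n2)))
            for i in range(n1)
        ])
        # Update q(z2) using the freshly updated q(z1).
        q2 = normalise([
            math.exp(sum(q1[i] * math.log(joint[i][j]) for i in range(n1)))
            for j in range(n2)
        ])
        history.append(elbo(joint, q1, q2))
        # Terminate when the objective stops increasing.
        if history[-1] - history[-2] < tol:
            break
    return q1, q2, history


# Hypothetical joint p(X, z1, z2) for a fixed observation X (entries sum to p(X)).
joint = [[0.20, 0.05], [0.02, 0.10]]
q1, q2, history = cavi(joint)
log_evidence = math.log(sum(sum(row) for row in joint))
```

In this toy case `history` never decreases and stays below `log_evidence`. The gap that remains at convergence is the KL divergence that the factorized q cannot remove.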

Next, we look at two different methods to complete our variational inference.

Direct Method

In this method, we explicitly define each $q_i$ and optimize the objective function. It is more complicated than the optimal method. If you don’t want to see the difference, you can skip to the optimal method.


The joint likelihood:

We can approximate the full posterior by picking up:

Then, we can write our variational objective function as:

You can see how complicated this already is and how complex it will become after we plug everything in. The next step is to take the derivative with respect to each parameter that we are optimizing. We can then use a gradient method for each optimization problem.

However, as said, this is too complicated to handle. In practice, the optimal method is the recommended way to do variational inference.

Optimal Method


Based on the mean-field assumption, we can factorize the q as:

As indicated in our general variational procedure, we should then pick up the distribution family for each model variable.

For $q(\alpha)$:

Since we haven’t calculated $q(w)$ yet, we have to carry it through the calculation. However, we can recognize that:

For $q(w)$:

So again we just carry over $q(\alpha)$ for now. We realize that:

We now know $q(w)$ is a Gaussian and $q(\alpha)$ is a Gamma. We can plug the respective expectation back to the formula we found and solve for the final expression.

By looking up online, we can find out that:

Put them all together:

VI for Bayesian Linear regression with unknown noise precision

1 Initialize $a_0^{\prime},b_0^{\prime},\mu_0^{\prime},\Sigma_0^{\prime}$ randomly

2 For iteration $t = 1,\dots$:

  • Update $q(\alpha)$ by setting:
  • Update $q(w)$ by setting:
  • Evaluate $\mathcal{L}$ to assess convergence.
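The loop can be sketched as follows, under several assumptions that are not spelled out above: a scalar weight w for brevity, the model $y_i \sim \text{Normal}(x_i w, \alpha^{-1})$, $w \sim \text{Normal}(0, \lambda^{-1})$ with $\lambda$ fixed, $\alpha \sim \text{Gamma}(a, b)$ in shape-rate form, and the standard mean-field updates $a^{\prime} = a + N/2$, $b^{\prime} = b + \frac{1}{2}\sum_i \mathbb{E}_{q(w)}[(y_i - x_i w)^2]$, with $q(w)$ Gaussian using $\mathbb{E}_{q(\alpha)}[\alpha] = a^{\prime}/b^{\prime}$ in place of $\alpha$. To stay short, the sketch monitors the change in $b^{\prime}$ instead of evaluating $\mathcal{L}$, and it starts $q(w)$ at the prior rather than at random values. The name `vi_model_v2` is hypothetical.

```python
import random


def vi_model_v2(x, y, lam, a, b, n_iter=100, tol=1e-10):
    """Mean-field VI with q(alpha) = Gamma(a_q, b_q) and q(w) = Normal(mu, s2).

    Assumed model: y_i ~ Normal(x_i * w, 1/alpha), w ~ Normal(0, 1/lam),
    alpha ~ Gamma(a, b) with b a rate parameter; w is a scalar for brevity.
    """
    n = len(x)
    sxx = sum(xi * xi for xi in x)
    sxy = sum(xi * yi for xi, yi in zip(x, y))
    mu, s2 = 0.0, 1.0 / lam          # start q(w) at the prior
    a_q = a + 0.5 * n                # this parameter never changes
    b_q = b
    for _ in range(n_iter):
        # Update q(alpha): expected squared residual under q(w).
        exp_sq_resid = sum((yi - xi * mu) ** 2 for xi, yi in zip(x, y)) + s2 * sxx
        b_new = b + 0.5 * exp_sq_resid
        e_alpha = a_q / b_new
        # Update q(w): Gaussian with E[alpha] in place of alpha.
        s2 = 1.0 / (lam + e_alpha * sxx)
        mu = s2 * e_alpha * sxy
        # Stand-in for evaluating L: stop when b_q stops changing.
        converged = abs(b_new - b_q) < tol
        b_q = b_new
        if converged:
            break
    return mu, s2, a_q, b_q


# Synthetic data: true weight 2.0, noise standard deviation 0.5 (precision 4).
random.seed(0)
x = [random.uniform(-1.0, 1.0) for _ in range(500)]
y = [2.0 * xi + random.gauss(0.0, 0.5) for xi in x]
mu, s2, a_q, b_q = vi_model_v2(x, y, lam=1.0, a=1.0, b=1.0)
```

The result is a distribution for both variables: `mu` and `s2` describe $q(w)$, while `a_q` and `b_q` describe $q(\alpha)$, whose mean `a_q / b_q` should sit near the true noise precision on this synthetic data.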

A few comments can be made here:

1 The order in which we update the parameters should not matter. Eventually, they will converge to whatever they should end up with.

2 It is sometimes unavoidable that we have to calculate the variational objective function, which we have seen is fairly complicated. To check convergence, however, we have to do so.

3 In EM, we get a point estimate of the model variable that we want to learn from data. In VI, we get a full distribution for each model variable.