Expectation-Maximization Algorithm in 10 minutes

A quick walk-through of the Expectation-Maximization (EM) algorithm and its cousins.


Introduction

This post explains the Expectation-Maximization (EM) algorithm from scratch in a fairly concise way.

The EM algorithm is one of the most elegant and widely used machine learning methods, but introductory courses sometimes do not explain it thoroughly. What is so elegant about EM is that, as we shall see, it follows from nothing more than the most fundamental laws of probability.

Many variants of EM have been developed over the years. An important class of statistical machine learning methods called Variational Inference is also closely connected to EM. Even if you are unfamiliar with Variational Inference, you may have heard of a concrete example such as the variational autoencoder (VAE). Understanding EM makes such methods much easier to understand.

Since the core ideas of EM find many applications in both classical machine learning and deep neural networks, it is worthwhile to have an intuitive and thorough understanding of EM, which this note attempts to provide.

$$ \newcommand{\argmin}{\mathop{\mathrm{argmin}}} \newcommand{\argmax}{\mathop{\mathrm{argmax}}} \renewcommand{\vec}[1]{\boldsymbol{#1}} $$

Notations

  • Random variables $X$, probability distribution $P(X)$
  • Probability density function (PDF) $p(\cdot)$, evaluated at value $x$: $p(X=x)$ with $p(x)$ as a shorthand
  • A PDF with parameter $\theta$ is denoted $p_\theta(x)$ or, equivalently, $p(x\vert \theta)$ or $p(x;\theta)$
  • Expectation of $f(x)$ according to distribution $P$: $\mathbb{E}_{x\sim P}\left[f(x)\right]$
  • A set is denoted $\{x_i\}$ or by a calligraphic letter such as $\mathcal X$

Maximum likelihood

Suppose we have data coming from a distribution $P_D(X)$. We want to come up with a model for $x$ parameterized by $\theta$, $p(x;\theta)$ (equivalently $p_{\theta}(x)$), that best approximates the real data distribution. Further, assume all data samples are independent and identically distributed (iid) according to $P_D(X)$.

To find $\theta$ under the maximum likelihood scheme, we solve

$$ \begin{equation} \begin{split} \hat{\theta}_{MLE} &= \argmax_{\theta} \ell(\theta) \\ &= \argmax_{\theta} \sum_{i} \log\left( p_{\theta}(x_i) \right) \end{split} \end{equation} $$
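As a concrete illustration of this objective, here is a minimal Python sketch. It assumes, only for illustration, that the model is a single 1D Gaussian with $\theta = (\mu, \sigma^2)$. For that model the maximizer has a closed form: the sample mean and the biased sample variance. The function names are hypothetical helpers, not from any library.

```python
import math
import random

def gaussian_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def log_likelihood(data, mu, var):
    """ell(theta) = sum_i log p_theta(x_i) with theta = (mu, var)."""
    return sum(gaussian_logpdf(x, mu, var) for x in data)

def gaussian_mle(data):
    """Closed-form argmax of log_likelihood: sample mean and biased sample variance."""
    n = len(data)
    mu = sum(data) / n
    var = sum((x - mu) ** 2 for x in data) / n
    return mu, var

rng = random.Random(0)
data = [rng.gauss(2.0, 1.5) for _ in range(500)]  # synthetic iid samples
mu_hat, var_hat = gaussian_mle(data)
best = log_likelihood(data, mu_hat, var_hat)
print(mu_hat, var_hat, best)
```

Moving either parameter away from `(mu_hat, var_hat)` lowers `log_likelihood`, which is exactly what the argmax above promises.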

Motivation for EM

We might encounter situations where, in addition to the observed data $\{x_i\}$, we have missing or hidden data $\{z_i\}$. These might literally be data that are missing for some reason. Or, more interestingly, they might be due to our modeling choice: we might prefer a model with a set of meaningful but hidden variables $\{z_i\}$ that help explain the "causes" of $\{x_i\}$. Examples in this category include Gaussian (and other) mixture models and latent Dirichlet allocation (LDA).


In either case, we need a model for the joint distribution of $x$ and $z$, $p(x,z;\theta)$. It may arise from assumptions (in the case of missing data) or from models of the marginal density $p(z; \theta)$ and the conditional density $p(x\vert z; \theta)$. The log-likelihood can then be expressed as

$$ \begin{equation} \begin{split} \ell(\theta) &= \sum_i \log\left( p_{\theta}(x_i) \right)\\ &= \sum_i \log\left( \sum_{z} p_{\theta}(x_i, Z=z) \right)\\ &= \sum_i \log\left( \sum_{z} p_{\theta}(x_i\vert Z=z)p_{\theta}(Z=z) \right) \end{split} \end{equation} $$
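The sketch below evaluates this incomplete-data log-likelihood for a discrete latent variable. It assumes a 1D Gaussian mixture whose parameters $\theta$ are mixing weights, means, and variances, the model described in the Gaussian mixture example later in this post. The inner sum over $z$ is computed with the log-sum-exp trick to avoid underflow. The function names are illustrative.

```python
import math

def gaussian_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mixture_log_likelihood(data, weights, means, variances):
    """ell(theta) = sum_i log sum_z p(x_i | Z=z) p(Z=z) for a 1D Gaussian mixture."""
    total = 0.0
    for x in data:
        # log p(x_i, Z=k) = log p(Z=k) + log p(x_i | Z=k), one entry per component k
        log_joint = [math.log(w) + gaussian_logpdf(x, m, v)
                     for w, m, v in zip(weights, means, variances)]
        total += logsumexp(log_joint)  # the sum over z sits inside the log
    return total

data = [-2.1, -1.7, 0.3, 2.8, 3.4]
print(mixture_log_likelihood(data, [0.4, 0.6], [-2.0, 3.0], [1.0, 1.0]))
```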

Direct maximization with respect to $\theta$ might be challenging because of the summation over $z$ inside the log. The problem would be much easier if we knew the values of $z$: it would then be the ordinary maximum likelihood problem with all data available.

$$ \begin{equation} \begin{split} \ell(\theta) &= \sum_i \log\left(p_{\theta}(x_i\vert Z=z_i)p_{\theta}(Z=z_i) \right) \\ &= \sum_i \log\left(p_{\theta}(x_i, z_i) \right) \end{split} \end{equation} $$
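Take the same assumed 1D Gaussian mixture. Once every $z_i$ is known, the objective splits into independent pieces, and each piece has a closed-form maximizer. The mixing weights are the class frequencies, and each component's mean and variance are the sample statistics of its own points. The minimal sketch below assumes integer labels $0,\dots,K-1$ and at least one point per component (`complete_data_mle` is a hypothetical name):

```python
def complete_data_mle(data, labels, n_components):
    """Maximize sum_i log p_theta(x_i, z_i) for a 1D Gaussian mixture with observed labels.

    Assumes labels are integers in range(n_components) and every component
    has at least one point. Returns (weights, means, variances).
    """
    n = len(data)
    weights, means, variances = [], [], []
    for k in range(n_components):
        xs = [x for x, z in zip(data, labels) if z == k]   # points with z_i = k
        n_k = len(xs)
        mu_k = sum(xs) / n_k
        weights.append(n_k / n)                               # fraction of points in class k
        means.append(mu_k)                                    # class sample mean
        variances.append(sum((x - mu_k) ** 2 for x in xs) / n_k)  # biased class variance
    return weights, means, variances

data = [-2.2, -1.8, -2.0, 2.9, 3.1, 3.3]
labels = [0, 0, 0, 1, 1, 1]
print(complete_data_mle(data, labels, 2))
```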

The collection $(\{x_i\}, \{z_i\})$ is called the complete data. Naturally, $\{x_i\}$ is the incomplete data and $\{z_i\}$ is the latent data/variable.

Roughly speaking, the EM algorithm is an iterative method that lets us guess $z_i$ based on $x_i$ (and the current estimate $\hat\theta$ of the model parameter). With the guessed "fill-in" $z_i$, we have complete data, and we maximize the complete-data log-likelihood with respect to $\theta$. Alternating these two steps, we iteratively improve our guesses of the latent variables $z$ and the parameter $\theta$ until convergence.

In slightly more detail, instead of guessing a single value of $z$, we guess the distribution of $z$ given $x$, i.e., $p(z\vert x;\hat\theta)$. We then maximize the expected complete-data log-likelihood, $\sum_i \mathbb{E}_{z \sim p(z\vert x_i;\hat\theta)}\left[\log p_\theta(x_i, z)\right]$, with respect to $\theta$. Up to a term that does not depend on $\theta$, this quantity serves as a proxy (lower bound) for the true objective $\sum_i \log p_{\theta}(x_i)$. We repeat until convergence.

(In fact, guessing a single value for $z$ is also a workable strategy. It corresponds to a variant of EM, and it is what the well-known K-means algorithm does: it assigns a "hard" label to each data point.)

The nice thing about EM is its theoretical guarantee: the true objective never decreases from one iteration to the next, even though we work directly with a proxy (lower bound) of it. Note, however, that the rate of convergence depends on the problem. Convergence to a global optimum is also not guaranteed: EM may end up at a local optimum.

Formulation

As before, we start with the log-likelihood

$$ \begin{equation} \begin{split} \ell(\theta) &= \sum_i \log\left( p_{\theta}(x_i) \right) \\ &= \sum_i \log\left( \int p_{\theta}(x_i, z) dz \right)\\ &= \sum_i \log\left( \int \frac{p_{\theta}(x_i, z)}{q(z)} q(z) dz \right) \\ &= \sum_i \log\left( \mathbb{E}_{z \sim Q} \left[ \frac {p_{\theta}(x_i, z)}{q(z)} \right] \right)\\ &\ge \sum_i \mathbb{E}_{z \sim Q} \left[\log\left( \frac {p_{\theta}(x_i,z)}{q(z)} \right) \right] \end{split} \label{eq:jensen} \end{equation} $$

Here I switched the summation over $z$ to an integral, assuming $z$ is continuous, just to show that this is also possible. The last step uses Jensen's inequality and the concavity of the log function.
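Jensen's inequality here is easy to check numerically for a continuous $z$. The sketch below assumes a toy model, chosen only because everything in it has a closed form: $z \sim \mathcal N(0,1)$ and $x\vert z \sim \mathcal N(z,1)$. Then $p(x) = \mathcal N(0,2)$ and $p(z\vert x) = \mathcal N(x/2, 1/2)$. The code estimates the right-hand side by Monte Carlo sampling from a Gaussian $q$. The names and the sample size are arbitrary choices.

```python
import math
import random

def normal_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

# Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1).
def log_joint(x, z):
    return normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)

def elbo_mc(x, q_mean, q_var, n_samples=20000, seed=0):
    """Monte Carlo estimate of E_{z~q}[log p(x, z) - log q(z)] for q = N(q_mean, q_var)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        total += log_joint(x, z) - normal_logpdf(z, q_mean, q_var)
    return total / n_samples

x = 1.5
log_evidence = normal_logpdf(x, 0.0, 2.0)  # exact log p(x), the left-hand side
print(log_evidence, elbo_mc(x, 0.0, 1.0), elbo_mc(x, x / 2, 0.5))
```

When $q$ is the posterior, the ratio $p(x,z)/q(z)$ is the constant $p(x)$. Every sample then gives exactly $\log p(x)$, and the bound holds with equality. With a mismatched $q$, the estimate falls below $\log p(x)$.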

So far, we have placed no restrictions on the distribution $Q$. We only require that $q(z)$ be a probability density function that is positive wherever $p_\theta(x_i,z)$ is.

We define the last quantity above as $\mathcal L(q,\theta)$. It is usually called the ELBO (Evidence Lower BOund), since it is a lower bound on the log-likelihood (the log "evidence") $\ell(\theta)$.

$$ \begin{equation} \mathcal L(q,\theta) = \sum_i \mathbb{E}_{z \sim Q} \left[\log\left( \frac {p_{\theta}(x_i,z)}{q(z)} \right) \right] \end{equation} $$

To recap: our goal is to maximize $\ell(\theta)$. By exchanging the order of the log and the integral over $z$, we obtained a lower bound $\mathcal L$.

Using $\int q(z)\,dz = 1$ to write $\log p_\theta(x_i)$ as an expectation under $q$, we can show that the difference between $\ell(\theta)$ and $\mathcal L(q,\theta)$ is

$$ \begin{equation} \begin{split} \ell(\theta) - \mathcal L(q,\theta) & = \sum_i \int q(z) \left(\log (p_\theta(x_i)) - \log\left(\frac{p_\theta(x_i,z)}{q(z)}\right)\right) dz\\ &= \sum_i \int q(z) \log\left(\frac{q(z)}{\frac{p_\theta(x_i,z)}{p_\theta(x_i)}}\right) dz \\ &= \sum_i \int q(z) \log\left(\frac{q(z)}{p_\theta(z\vert x_i)}\right) dz \\ &= \sum_i D_{KL}(q(z) \| p_\theta(z\vert x_i)) \end{split} \end{equation} $$

where we used the definition of the Kullback-Leibler (KL) divergence $D_{KL}$:

$$D_{KL}(P \| Q)= \int p(x) \log \left( \frac{p(x)}{q(x)} \right) dx = \mathbb{E}_{x\sim P}\left[\log\left(\frac{p(x)}{q(x)}\right)\right]$$

In general, the KL divergence is nonnegative and is zero if and only if $q(x) = p(x)$ (almost everywhere). So in our case, the equality $\ell(\theta) = \mathcal L(q,\theta)$ holds if and only if $q(z) = p_\theta(z\vert x_i)$ for every $i$. When this happens, we say the bound is tight. The target $p_\theta(z\vert x_i)$ differs from one data point to the next, so we need one such distribution per data point. It therefore makes sense to write it as $q(z\vert x_i)$, which makes the dependence on $x_i$ explicit.
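Here is a minimal numerical check of this identity for a discrete $z$ and a single data point. It assumes an arbitrary 3-component 1D Gaussian mixture as the model, with all parameter values made up for illustration. For any $q$, the gap $\log p_\theta(x) - \mathcal L(q,\theta)$ equals $D_{KL}(q \| p_\theta(z\vert x))$, and the gap vanishes when $q$ is the posterior.

```python
import math

def normal_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

# Illustrative mixture parameters theta = (weights, means, variances) and one data point.
weights, means, variances = [0.2, 0.5, 0.3], [-2.0, 0.0, 3.0], [1.0, 0.5, 2.0]
x = 0.7

log_joint = [math.log(w) + normal_logpdf(x, m, v)
             for w, m, v in zip(weights, means, variances)]   # log p(x, Z=k)
log_px = math.log(sum(math.exp(a) for a in log_joint))          # log p(x)
posterior = [math.exp(a - log_px) for a in log_joint]           # p(Z=k | x)

def elbo(q):
    """L(q, theta) = E_q[log p(x, z) - log q(z)] for this single data point."""
    return sum(qk * (a - math.log(qk)) for qk, a in zip(q, log_joint) if qk > 0)

def kl(q, p):
    """D_KL(q || p) for discrete distributions."""
    return sum(qk * math.log(qk / pk) for qk, pk in zip(q, p) if qk > 0)

q = [0.6, 0.3, 0.1]  # an arbitrary distribution over the three components
print(log_px - elbo(q), kl(q, posterior))
```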

EM algorithm and monotonicity guarantee

The EM algorithm is remarkably simple, and it goes as follows.

  • E-step (of the $t$-th iteration):
      • For each $x_i$, let $q^t(z) = p(z \vert x_i; \hat\theta^{t-1})$, computed as in Eq. $\ref{eq:E}$
      • Because of this particular choice of $q^t$, the bound is tight at the current estimate $\hat\theta^{t-1}$: $\mathcal L(q^t,\hat\theta^{t-1}) = \ell(\hat\theta^{t-1})$
  • M-step (of the $t$-th iteration):
      • Maximize $\mathcal L(q^t,\theta)$ with respect to $\theta$, as in Eq. $\ref{eq:M}$
      • This step improves the ELBO by finding a better $\theta$: $\mathcal L(q^t,\hat\theta^t) \ge \mathcal L(q^t,\hat\theta^{t-1})$

The calculation in the E-step is

$$ \begin{equation}\label{eq:E} p(z\vert x_i; \hat\theta^{t-1}) = \frac{p(x_i\vert z; \hat\theta^{t-1})p(z; \hat\theta^{t-1})}{\int p(x_i\vert z; \hat\theta^{t-1})p(z; \hat\theta^{t-1}) dz} \end{equation} $$
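For a discrete $z$, the integral in the denominator becomes a sum. The minimal sketch below implements this E-step for an assumed 1D Gaussian mixture. It computes the posteriors $p(Z=k\vert x_i;\hat\theta^{t-1})$, often called responsibilities, in log space for numerical stability. `e_step` is a hypothetical name.

```python
import math

def normal_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def e_step(data, weights, means, variances):
    """Return r[i][k] = p(Z=k | x_i; theta) via Bayes' rule.

    The denominator is the sum over components of p(x_i | Z=k) p(Z=k);
    subtracting the max log term before exponentiating avoids underflow.
    """
    resp = []
    for x in data:
        log_joint = [math.log(w) + normal_logpdf(x, m, v)
                     for w, m, v in zip(weights, means, variances)]
        top = max(log_joint)
        unnorm = [math.exp(a - top) for a in log_joint]
        norm = sum(unnorm)
        resp.append([u / norm for u in unnorm])
    return resp

r = e_step([-2.0, 0.1, 3.0], [0.5, 0.5], [-2.0, 3.0], [1.0, 1.0])
print(r)
```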

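Spelled out, the M-step maximizes $\mathcal L(q^t,\theta)$ as follows. The entropy term $-\mathbb{E}_{z \sim Q^t}\left[\log q^t(z)\right]$ does not depend on $\theta$, so it is dropped in the second line.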

$$ \begin{equation} \begin{split} \hat\theta^t &= \argmax_{\theta} \mathcal L(q^t,\theta) \\ &= \argmax_{\theta} \sum_i \mathbb{E}_{z \sim Q^t} \left[\log\left(p(x_i,z;\theta) \right) \right] \\ &= \argmax_{\theta} \sum_i \int p(z\vert x_i; \hat\theta^{t-1}) \log\left(p(x_i,z;\theta)\right) dz \\ \end{split} \label{eq:M} \end{equation} $$

With these pieces in place, it is easy to show that EM never decreases the objective $\ell(\theta)$:

$$ \begin{equation}\label{eq:monotone} \ell(\theta^{t-1}) \underset{\mathrm{E-step}}{=} \mathcal L(q^t,\theta^{t-1}) \underset{\mathrm{M-step}}{\le} \mathcal L(q^t,\theta^t) \underset{\mathrm{Jensen}}{\le} \ell(\theta^{t}) \end{equation} $$
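The monotonicity guarantee can be observed directly. Below is a minimal sketch of plain EM for an assumed 1D Gaussian mixture on synthetic data. It records $\ell(\hat\theta)$ at the start of every iteration, and its M-step uses the standard closed-form GMM updates. Apart from floating-point round-off, the recorded sequence should never decrease. All names and the synthetic data are illustrative.

```python
import math
import random

def gaussian_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def em_gmm(data, weights, means, variances, n_iter=50):
    """Plain EM for a 1D Gaussian mixture.

    Returns the final parameters and history, where history[t] is ell at the
    parameters used in the E-step of iteration t (history[0]: initial values).
    """
    weights, means, variances = list(weights), list(means), list(variances)
    n = len(data)
    history = []
    for _ in range(n_iter):
        # E-step: r[i][k] = p(Z=k | x_i; current theta), while accumulating ell(theta).
        resp, ll = [], 0.0
        for x in data:
            lj = [math.log(w) + gaussian_logpdf(x, m, v)
                  for w, m, v in zip(weights, means, variances)]
            top = max(lj)
            s = sum(math.exp(a - top) for a in lj)
            ll += top + math.log(s)
            resp.append([math.exp(a - top) / s for a in lj])
        history.append(ll)
        # M-step: closed-form maximizer of the expected complete-data log-likelihood.
        for k in range(len(weights)):
            n_k = sum(r[k] for r in resp)
            weights[k] = n_k / n
            means[k] = sum(r[k] * x for r, x in zip(resp, data)) / n_k
            variances[k] = sum(r[k] * (x - means[k]) ** 2 for r, x in zip(resp, data)) / n_k
    return weights, means, variances, history

rng = random.Random(1)
data = ([rng.gauss(-2.0, 1.0) for _ in range(200)]
        + [rng.gauss(3.0, 0.7) for _ in range(300)])
w, m, v, hist = em_gmm(data, [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
print(m, hist[0], hist[-1])
```

The chain of (in)equalities above guarantees that `hist` is non-decreasing. It says nothing about reaching the global optimum, which depends on the initialization.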

Why the "E" in E-step

By the way, it is called the E-step because in that step we carry out the calculation needed to write $\mathcal L(q,\theta)$ as a function of $\theta$, which we then maximize in the M-step. Up to a term that does not depend on $\theta$, this function is the expectation of the complete-data log-likelihood over the estimated distribution of the latent variable $z$.

Expectation-Maximization as Maximization-Maximization

Consider the decomposition $\ell(\hat\theta^{t-1}) = \mathcal L(q,\hat\theta^{t-1}) + \sum_i D_{KL}(q \,\|\, p_{\hat\theta^{t-1}}(z\vert x_i))$. Its left-hand side does not depend on $q$. So the E-step choice $q^t(z) = p(z\vert x_i;\hat\theta^{t-1})$, which drives the KL term to zero, maximizes $\mathcal L(q,\hat\theta^{t-1})$ with respect to $q$. The M-step then maximizes the same function with respect to $\theta$. EM is therefore alternating maximization of the ELBO with respect to $q$ and $\theta$.

$$ \begin{equation} \begin{split} & \text{E-step:}\hspace{4pt}q^t(z) = \argmax_q \mathcal L(q,\hat\theta^{t-1})\\ & \text{M-step:}\hspace{4pt}\hat\theta^t = \argmax_\theta \mathcal L(q^t,\theta) \end{split} \end{equation} $$

This maximization-maximization view justifies a partial E-step, used when the exact E-step calculation is intractable. It also justifies a partial M-step, i.e., only finding a $\theta$ that increases the ELBO rather than maximizing it. Treating direct maximization of the ELBO as the goal also connects EM closely to Variational Inference, discussed briefly below.

Example: Gaussian Mixture

In the context of the Gaussian Mixture Model (GMM), the latent variable $z_i$ associated with $x_i$ takes values in $\{1,2,\dots,n_{g}\}$, where $n_g$ is the number of Gaussians in the model. Thus $z_i$ indicates which Gaussian cluster the observed data point $x_i$ belongs to. The parameters $\theta$ include the mixing weights $\vec \pi = [\pi_1, \pi_2, \dots, \pi_{n_g}]$, with $\sum_{k=1}^{n_g} \pi_k = 1$ and $\pi_k > 0$. These parameterize the marginal distribution of $z$, $P(Z=k;\vec \pi) = \pi_k$. The parameters $\theta$ also include the means and variances that parameterize the conditional distributions $P(X \vert Z=k; \mu_k, \sigma_k^2) = \mathcal N(\mu_k, \sigma_k^2)$.
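For a one-dimensional GMM, both steps have closed forms. Write $N$ for the number of data points and $r_{ik}$ for the E-step responsibilities, where the parameters on the right of the first line are those of $\hat\theta^{t-1}$. Maximizing Eq. $\ref{eq:M}$, with the constraint $\sum_k \pi_k = 1$ handled by a Lagrange multiplier, gives the standard updates below. The variance update uses the newly computed $\mu_k$.

$$ \begin{equation} \begin{split} r_{ik} &= p(z_i = k \vert x_i; \hat\theta^{t-1}) = \frac{\pi_k\, \mathcal N(x_i; \mu_k, \sigma_k^2)}{\sum_{j=1}^{n_g} \pi_j\, \mathcal N(x_i; \mu_j, \sigma_j^2)} \\ \pi_k &\leftarrow \frac{1}{N}\sum_{i=1}^{N} r_{ik}, \qquad \mu_k \leftarrow \frac{\sum_{i} r_{ik}\, x_i}{\sum_{i} r_{ik}}, \qquad \sigma_k^2 \leftarrow \frac{\sum_{i} r_{ik}\,(x_i - \mu_k)^2}{\sum_{i} r_{ik}} \end{split} \end{equation} $$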

For a detailed walk-through, see Andrew Ng's CS229 lecture notes and video.

Variants and extensions of EM

GEM and CEM

A popular variant of EM merely finds a $\hat\theta^t$ in Eq. $\ref{eq:M}$ that increases (rather than maximizes) $\mathcal L(q^t,\theta)$. Eq. $\ref{eq:monotone}$ shows that the monotonicity guarantee still holds in this situation, because the M-step inequality only requires $\mathcal L(q^t,\hat\theta^t) \ge \mathcal L(q^t,\hat\theta^{t-1})$. This algorithm was proposed in the original EM paper, where it is called Generalized EM (GEM).
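To illustrate GEM, the sketch below keeps the exact E-step but replaces the full M-step with a single gradient-ascent step on the component means. The weights and variances stay fixed, and the model is an assumed 1D Gaussian mixture. $\mathcal L(q^t,\theta)$ is a concave quadratic in each mean $\mu_k$. A step of size $\alpha\,\sigma_k^2/N_k$ with $0<\alpha<2$ (where $N_k=\sum_i r_{ik}$) therefore increases it, which is all GEM requires. The names and the choice $\alpha = 0.5$ are assumptions for illustration.

```python
import math
import random

def normal_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def responsibilities(data, weights, means, variances):
    """Exact E-step: posteriors r[i][k] and the log-likelihood ell(theta)."""
    resp, ll = [], 0.0
    for x in data:
        lj = [math.log(w) + normal_logpdf(x, m, v)
              for w, m, v in zip(weights, means, variances)]
        top = max(lj)
        s = sum(math.exp(a - top) for a in lj)
        ll += top + math.log(s)
        resp.append([math.exp(a - top) / s for a in lj])
    return resp, ll

def gem_means(data, weights, means, variances, n_iter=30, alpha=0.5):
    """GEM sketch: the M-step is one gradient-ascent step on each mean, not a full argmax."""
    means = list(means)
    history = []
    for _ in range(n_iter):
        resp, ll = responsibilities(data, weights, means, variances)
        history.append(ll)
        for k in range(len(means)):
            n_k = sum(r[k] for r in resp)
            # Gradient of the ELBO with respect to mu_k (responsibilities held fixed).
            grad = sum(r[k] * (x - means[k]) for r, x in zip(resp, data)) / variances[k]
            means[k] += alpha * variances[k] / n_k * grad
    return means, history

rng = random.Random(2)
data = ([rng.gauss(-2.0, 1.0) for _ in range(150)]
        + [rng.gauss(2.5, 1.0) for _ in range(150)])
m, hist = gem_means(data, [0.5, 0.5], [-0.5, 0.5], [1.0, 1.0])
print(m, hist[0], hist[-1])
```

With this step size, each update moves $\mu_k$ a fraction $\alpha$ of the way toward its full M-step value, so the ELBO rises without being maximized.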

Another variant is the point-estimate version mentioned earlier. Instead of setting $q^t(z) = p(z\vert x_i; \hat\theta^{t-1})$ in the E-step, we take $z$ to be a single value, the most probable one: $\hat{z}_i^t=\argmax_z p(z\vert x_i; \hat\theta^{t-1})$, or equivalently $q^t(z) = \delta(z-\hat{z}_i^t)$. The integral in Eq. $\ref{eq:M}$ is then greatly simplified. However, the first equality in Eq. $\ref{eq:monotone}$ no longer holds, and we lose the theoretical guarantee. This algorithm is also called Classification EM (CEM).
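Add two extra assumptions: equal mixing weights and a shared, fixed variance. The CEM "E-step" $\argmax_z p(z\vert x_i;\hat\theta^{t-1})$ then reduces to picking the nearest mean. The M-step reduces to averaging the points assigned to each component, which is exactly K-means. A minimal 1D sketch under those assumptions (the function name is hypothetical):

```python
def hard_em_kmeans(data, means, n_iter=20):
    """Hard-assignment (CEM-style) EM for a 1D mixture with equal weights and a
    shared fixed variance; under those assumptions it is exactly K-means."""
    means = list(means)
    labels = []
    for _ in range(n_iter):
        # "E-step": q(z) is a point mass at the most probable component,
        # which under these assumptions is simply the nearest mean.
        labels = [min(range(len(means)), key=lambda k: (x - means[k]) ** 2) for x in data]
        # "M-step": complete-data MLE of each mean from its hard-assigned points.
        for k in range(len(means)):
            members = [x for x, z in zip(data, labels) if z == k]
            if members:  # keep the old mean if a cluster ends up empty
                means[k] = sum(members) / len(members)
    return means, labels

data = [-2.3, -1.9, -2.1, -1.6, 2.8, 3.2, 3.0, 2.7]
means, labels = hard_em_kmeans(data, [-1.0, 1.0])
print(means, labels)
```

K-means never increases its own objective, the within-cluster sum of squares. That is a different quantity from $\ell(\theta)$, which is why the EM guarantee above does not carry over.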

Stochastic EM

As Eq. $\ref{eq:M}$ shows, each update of $\theta$ requires a pass over all data points, which can be slow for large data sets. In much the same spirit as stochastic gradient descent, we can instead sample subsets of the data and run the E- and M-steps on these mini-batches. The same idea applies to the updates of global latent variables (such as $\theta$) in the variational inference methods discussed below.
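One concrete way to run EM on mini-batches is the incremental EM variant of Neal and Hinton (reference 2), sketched below. Each mini-batch gets a partial E-step that refreshes only its own responsibilities. The running sufficient statistics are corrected by the change, and the M-step is recomputed from those global statistics. This is one specific scheme, not the only way to implement the idea. The sketch assumes a 1D Gaussian mixture, and the names, batch size, and variance floor are illustrative choices.

```python
import math
import random

def gaussian_logpdf(x, mu, var):
    """Log-density of N(mu, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def log_joint(x, weights, means, variances):
    """List of log p(x, Z=k; theta), one entry per component k."""
    return [math.log(w) + gaussian_logpdf(x, m, v)
            for w, m, v in zip(weights, means, variances)]

def posterior(x, weights, means, variances):
    """p(Z=k | x; theta) for each component k."""
    lj = log_joint(x, weights, means, variances)
    top = max(lj)
    e = [math.exp(a - top) for a in lj]
    s = sum(e)
    return [u / s for u in e]

def log_lik(data, weights, means, variances):
    """ell(theta) for the whole data set."""
    total = 0.0
    for x in data:
        lj = log_joint(x, weights, means, variances)
        top = max(lj)
        total += top + math.log(sum(math.exp(a - top) for a in lj))
    return total

def incremental_em(data, weights, means, variances, batch_size=50, n_epochs=5, seed=0):
    """Mini-batch EM in the incremental style of Neal & Hinton for a 1D Gaussian mixture."""
    rng = random.Random(seed)
    n, k = len(data), len(weights)
    weights, means, variances = list(weights), list(means), list(variances)
    # One full E-step initializes the responsibilities and sufficient statistics.
    resp = [posterior(x, weights, means, variances) for x in data]
    s0 = [sum(r[j] for r in resp) for j in range(k)]                        # sum_i r_ij
    s1 = [sum(r[j] * x for r, x in zip(resp, data)) for j in range(k)]      # sum_i r_ij x_i
    s2 = [sum(r[j] * x * x for r, x in zip(resp, data)) for j in range(k)]  # sum_i r_ij x_i^2
    order = list(range(n))
    for _ in range(n_epochs):
        rng.shuffle(order)
        for start in range(0, n, batch_size):
            # Partial E-step: refresh only this mini-batch, swapping its old
            # contribution to the sufficient statistics for the new one.
            for i in order[start:start + batch_size]:
                x = data[i]
                new = posterior(x, weights, means, variances)
                for j in range(k):
                    d = new[j] - resp[i][j]
                    s0[j] += d
                    s1[j] += d * x
                    s2[j] += d * x * x
                resp[i] = new
            # M-step: closed-form GMM update from the global statistics.
            for j in range(k):
                weights[j] = s0[j] / n
                means[j] = s1[j] / s0[j]
                variances[j] = max(s2[j] / s0[j] - means[j] ** 2, 1e-6)  # floor guards round-off
    return weights, means, variances

rng = random.Random(3)
data = ([rng.gauss(-2.0, 1.0) for _ in range(300)]
        + [rng.gauss(3.0, 0.8) for _ in range(300)])
init = ([0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
fitted = incremental_em(data, *init)
print(fitted[1], log_lik(data, *init), log_lik(data, *fitted))
```

The run starts from a full E-step, and each partial E-step and each M-step can only increase the ELBO. As a result, the log-likelihood of the fitted parameters is no lower than that of the initial ones.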

Variational inference

Computing the optimal $q(z)$ in the E-step, i.e., $q(z) = p(z \vert x_i; \hat\theta^{t-1})$, might be intractable. In particular, the integral in the denominator of Eq. $\ref{eq:E}$ has no closed-form solution for many interesting models. In this case, we can view EM as maximization-maximization and look for better and better $q(z)$ to improve the ELBO. To set up such a variational optimization problem, we need to specify the family $\mathcal Q$ of distributions from which we choose $q(z)$. Depending on this choice, a number of interesting algorithms have been developed. The most popular is probably the mean-field approximation.

Note that a typical variational inference framework treats the parameter $\theta$ as a first-class random variable on which we perform inference (i.e., we compute $p(\theta\vert x)$), rather than as a single maximum-likelihood point estimate. Thus $\theta$ becomes part of the latent variables and is absorbed into the notation $z$. So $z$ includes global variables such as $\theta$ and local variables such as the latent label $z_i$ associated with each data point $x_i$.

In the mean-field method, the constraint we put on $q(z)$ is that it factorizes, i.e., $q(z) = \prod_k q_k(z_k)$. In other words, we assume that under $q$ the latent variables (or groups of them) are mutually independent. This seemingly simple assumption greatly simplifies the integrals involved, especially the expectations of the log-likelihood. It leads to the coordinate ascent variational inference (CAVI) algorithm, which updates one factor $q_k$ at a time and admits closed-form updates for certain model families. The coordinate updates on local variables correspond to the E-step in EM, while the updates on global variables correspond to the M-step.
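As a small, self-contained illustration of mean-field CAVI, consider the textbook model of a univariate Gaussian with unknown mean $\mu$ and precision $\tau$ under a conjugate normal-gamma prior. This model is an assumption, chosen because both factor updates have closed forms. With $q(\mu,\tau)=q(\mu)\,q(\tau)$, the optimal $q(\mu)$ is Gaussian and the optimal $q(\tau)$ is a Gamma distribution. Each update only needs expectations under the other factor. The model has only global variables, so it shows the coordinate updates rather than the local/global split. The function and argument names are hypothetical.

```python
import random

def cavi_normal_gamma(data, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0, n_iter=50):
    """Mean-field CAVI, q(mu, tau) = q(mu) q(tau), for the assumed model
    x_i ~ N(mu, 1/tau),  mu | tau ~ N(mu0, 1/(lambda0 * tau)),  tau ~ Gamma(a0, rate=b0).

    Returns (mu_n, lambda_n, a_n, b_n) with q(mu) = N(mu_n, 1/lambda_n)
    and q(tau) = Gamma(shape=a_n, rate=b_n).
    """
    n = len(data)
    xbar = sum(data) / n
    mu_n = (lambda0 * mu0 + n * xbar) / (lambda0 + n)  # does not depend on q(tau)
    a_n = a0 + (n + 1) / 2                              # fixed by the model as well
    lambda_n = (lambda0 + n) * (a0 / b0)                # start from E[tau] under the prior
    b_n = b0
    for _ in range(n_iter):
        # Update q(tau) using expectations under the current q(mu).
        var_mu = 1.0 / lambda_n
        e_sq = sum((x - mu_n) ** 2 for x in data) + n * var_mu  # E[sum_i (x_i - mu)^2]
        e_prior = (mu_n - mu0) ** 2 + var_mu                     # E[(mu - mu0)^2]
        b_n = b0 + 0.5 * (e_sq + lambda0 * e_prior)
        # Update q(mu): its precision scales with E[tau] = a_n / b_n under q(tau).
        lambda_n = (lambda0 + n) * (a_n / b_n)
    return mu_n, lambda_n, a_n, b_n

rng = random.Random(4)
data = [rng.gauss(1.0, 0.5) for _ in range(400)]
mu_n, lambda_n, a_n, b_n = cavi_normal_gamma(data)
print("E[mu] =", mu_n, " E[tau] =", a_n / b_n)
```

Here $E_q[\tau] = a_n/b_n$ should land near $1/0.5^2 = 4$ for this synthetic data. The two factors are coupled only through $E_q[\tau]$ and the first two moments of $q(\mu)$.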

For more about this topic, see D. M. Blei, A. Kucukelbir, and J. D. McAuliffe, "Variational Inference: A Review for Statisticians," J. Am. Stat. Assoc., vol. 112, no. 518, pp. 859–877, 2017.


References


In no particular order:

  1. A. P. Dempster, N. M. Laird, and D. B. Rubin, "Maximum likelihood from incomplete data via the EM algorithm," J. R. Stat. Soc. Ser. B Methodol., vol. 39, no. 1, pp. 1–38, 1977.
  2. R. M. Neal and G. E. Hinton, "A View of the EM Algorithm that Justifies Incremental, Sparse, and other Variants," Learn. Graph. Model., pp. 355–368, 1998.
  3. J. A. Bilmes, "A Gentle Tutorial of the EM Algorithm and its Application to Parameter Estimation for Gaussian Mixture and Hidden Markov Models," Int. Comput. Sci. Inst., Berkeley, CA, Tech. Rep. TR-97-021, 1998.
  4. A. Roche, "EM algorithm and variants: an informal tutorial," pp. 1–17, 2011.
  5. M. R. Gupta, "Theory and Use of the EM Algorithm," Found. Trends® Signal Process., vol. 4, no. 3, pp. 223–296, 2010.
  6. M. I. Jordan, Z. Ghahramani, T. S. Jaakkola, and L. K. Saul, "Introduction to variational methods for graphical models," Mach. Learn., vol. 37, no. 2, pp. 183–233, 1999.
  7. M. J. Wainwright and M. I. Jordan, "Graphical Models, Exponential Families, and Variational Inference," Found. Trends® Mach. Learn., vol. 1, no. 1–2, pp. 1–305, 2008.
  8. M. D. Hoffman, D. M. Blei, C. Wang, and J. Paisley, "Stochastic Variational Inference," J. Mach. Learn. Res., vol. 14, pp. 1303–1347, 2013.
  9. D. M. Blei, A. Kucukelbir, and J. D. McAuliffe, "Variational Inference: A Review for Statisticians," J. Am. Stat. Assoc., vol. 112, no. 518, pp. 859–877, 2017.
  10. S. Mohamed, "Variational Inference for Machine Learning," Feb. 2015.
  11. Z. Ghahramani, "Variational Methods The Expectation Maximization (EM) algorithm," Apr. 2003.