
Expectation-Maximization

A variational derivation

Goal

Find the maximum likelihood estimate (MLE) (or the maximum a posteriori (MAP) estimate) when some data is missing.

Given $n$ iid-observations (data)

$$ {\bf X} = (\vec x_1, \vec x_2, \dots \vec x_n) $$

Note: $\vec x_i$ can be a sequence.

Model

$$ (\vec x, \vec z) \sim p_\theta $$

for some (unknown) $\theta \in \Theta$

  • $\vec z$: hidden (latent) random variables
  • typical model: $\vec z \rightarrow \vec x$, i.e. $p_\theta(\vec x) = \sum_{\vec z} p_\theta(\vec x \mid \vec z) p_\theta(\vec z)$ (a concrete example follows this list)
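As a concrete illustration (not part of the original derivation, but the standard example used below): a Gaussian mixture model, where the latent variable is the discrete component index $z \in \{1, \dots, K\}$ with $p_\theta(z = k) = \pi_k$ and $p_\theta(\vec x \mid z = k) = \mathcal N(\vec x \mid \vec \mu_k, \Sigma_k)$, so that

$$ p_\theta(\vec x) = \sum_{k=1}^{K} \pi_k \, \mathcal N(\vec x \mid \vec \mu_k, \Sigma_k), \qquad \theta = (\pi_1, \dots, \pi_K, \vec \mu_1, \dots, \vec \mu_K, \Sigma_1, \dots, \Sigma_K) $$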

Goal

Maximization of the likelihood (or log-likelihood) of the parameters given the observations.

$$ \theta_{mle} = \text{arg} \max_\theta p({\bf X} \mid \theta) $$

Marginal likelihood of the parameters w.r.t. the observed data:

$$ \ell(\theta) = p({\bf X} \mid \theta) = \prod_{i=1}^n \left(\sum_{\vec z_i} p(\vec x_i, \vec z_i \mid \theta )\right) $$

Note: For continuous latent variables the sum over $\vec z_i$ must be replaced by an integral.

$$ \ell(\theta) = p({\bf X} \mid \theta) = \prod_{i=1}^n \left(\int_{\mathcal Z} p(\vec x_i, \vec z_i \mid \theta ) d \vec z_i\right) $$

Marginal log-likelihood of the parameters w.r.t. the observed data:

$$ L(\theta) = \log p({\bf X} \mid \theta) = \sum_{i=1}^n \left(\log \sum_{\vec z_i} p(\vec x_i, \vec z_i \mid \theta )\right) $$

This log-likelihood is difficult to maximize directly: the sum is inside the log.

Complete-data log-likelihood:

$$ L_c(\theta) = \log p({\bf X, Z} \mid \theta) = \sum_{i=1}^n \log p(\vec x_i, \vec z_i \mid \theta ) $$

cannot be computed, since the $\vec z_i$ are unknown.

Here: for each observation $\vec x_i$ we have a hidden variable $\vec z_i$.
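To see this structure concretely, here is a minimal numerical sketch (under the illustrative Gaussian-mixture assumption from above; the parameter values `pi`, `mu`, `sigma` are arbitrary choices, not quantities from the text) that evaluates the marginal log-likelihood with the sum over $\vec z_i$ inside the log:

```python
# Minimal sketch: marginal log-likelihood of a two-component 1-D Gaussian mixture.
# The parameter values (pi, mu, sigma) are illustrative, not taken from the text.
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

x = np.array([-1.2, 0.3, 2.5, 2.9, -0.7])             # toy observations
pi = np.array([0.4, 0.6])                              # mixing weights p(z = k)
mu, sigma = np.array([-1.0, 3.0]), np.array([1.0, 0.5])

# log p(x_i, z_i = k | theta) for all i, k  -> shape (n, K)
log_joint = np.log(pi) + norm.logpdf(x[:, None], mu, sigma)

# L(theta) = sum_i log sum_k p(x_i, z_i = k | theta):
# the log cannot be pushed through the inner sum, which is what makes
# direct maximization of L(theta) awkward.
L = logsumexp(log_joint, axis=1).sum()
print(L)
```

Note how the complete-data terms `log_joint` are easy to evaluate, while the marginalization over $\vec z_i$ happens inside the log.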

Lower Bound

With arbitrary distributions $q({\bf z}_i)$ over the hidden variables and Jensen's inequality, we obtain a lower bound on the log-likelihood:

$$ \begin{align} L(\theta) &= \sum_{i} \log p({\bf x}_i \mid \theta ) \\ & = \sum_{i} \log \int_{\mathcal Z} p({\bf x}_i,{\bf z}_i \mid \theta ) d {\bf z}_i \\ & = \sum_{i} \log \int_{\mathcal Z} q({\bf z}_i) \frac{p({\bf x}_i, {\bf z}_i\mid \theta )}{q({\bf z}_i)} d {\bf z}_i \\ & \geq \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)} d {\bf z}_i \\ & = \sum_{i} \mathbb E_{q({\bf z}_i)} \left[\log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)}\right] \\ & = - \sum_{i} \mathcal D_{KL} \left( q({\bf z}_i) \mid \mid p({\bf x}_i, {\bf z}_i \mid \theta ) \right) \\ & = \mathcal L(q, \theta) \end{align} $$
  • We have a family of lower bounds: different choices of $q({\bf z}_i)$ select different members of the family. Maximizing the lower bound w.r.t. the $q$'s for a fixed $\theta$ gives a tighter lower bound.
  • $\mathcal L(q, \theta)$: variational lower bound (also known as the evidence lower bound, ELBO). In the $\mathcal D_{KL}$ line above, the second argument is the unnormalized joint $p({\bf x}_i, {\bf z}_i \mid \theta)$, which is a slight abuse of the KL notation.
  • Usual abuse of notation for the probability distributions: in general $q({\bf z}_i) \neq q({\bf z}_j)$ for $i \neq j$ (the argument determines the probability distribution).
  • Note: By using Jensen's inequality we pushed the $\log$ inside the integral/sum (a small numerical check of the bound follows this list).
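A small numerical check of the bound, again under the illustrative toy-mixture assumption: for any choice of the $q({\bf z}_i)$, the value $\mathcal L(q, \theta)$ stays below the exact marginal log-likelihood.

```python
# Sketch: evaluate the variational lower bound L(q, theta) for arbitrary q(z_i)
# and check that it does not exceed the exact log-likelihood L(theta).
# Model and parameter values are the illustrative toy mixture from above.
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

rng = np.random.default_rng(0)
x = np.array([-1.2, 0.3, 2.5, 2.9, -0.7])
pi = np.array([0.4, 0.6])
mu, sigma = np.array([-1.0, 3.0]), np.array([1.0, 0.5])

log_joint = np.log(pi) + norm.logpdf(x[:, None], mu, sigma)   # log p(x_i, z_i = k | theta)
L_theta = logsumexp(log_joint, axis=1).sum()                  # exact marginal log-likelihood

# arbitrary variational distributions: one categorical q(z_i) per observation
q = rng.dirichlet(np.ones(2), size=len(x))                    # shape (n, K), rows sum to 1

# L(q, theta) = sum_i E_{q(z_i)}[ log p(x_i, z_i | theta) - log q(z_i) ]
elbo = np.sum(q * (log_joint - np.log(q)))
print(elbo, "<=", L_theta, elbo <= L_theta)
```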

EM Algorithm

  • E-Step: $q^{(t+1)} \leftarrow \text{arg} \max_q \mathcal L(q,\theta^{(t)}) $
    Search within the family of lower bounds for the tightest member at the current position $\theta^{(t)}$.
  • M-Step: $\theta^{(t+1)} \leftarrow \text{arg} \max_{\theta} \mathcal L(q^{(t+1)},\theta)$
    Maximize the current lower bound (given by $q^{(t+1)}$) w.r.t. $\theta$ (a generic loop skeleton follows this list).
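A minimal skeleton of this alternating scheme, with the two steps passed in as callables (the helper names `e_step` and `m_step` are hypothetical; concrete versions for the toy mixture are sketched in the sections below):

```python
# Generic skeleton of EM as alternating maximization of the lower bound L(q, theta).
def em(x, theta0, e_step, m_step, n_iter=100):
    """e_step(x, theta) -> q maximizing L(q, theta); m_step(x, q) -> theta maximizing L(q, theta)."""
    theta = theta0
    for _ in range(n_iter):
        q = e_step(x, theta)      # E-Step: tighten the bound at the current theta
        theta = m_step(x, q)      # M-Step: maximize the current bound w.r.t. theta
    return theta
```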

E-Step in detail

The gap between the log-likelihood $L(\theta)$ and the lower bound $\mathcal L(q, \theta)$ is

$$ \begin{align} \text{GAP} &= L(\theta) - \mathcal L (q, \theta)\\ &= \sum_{i} \log p({\bf x}_i \mid \theta ) - \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)} d {\bf z}_i\\ &= \sum_{i} \log p({\bf x}_i \mid \theta )\int_{\mathcal Z} q({\bf z}_i) d{\bf z}_i - \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)} d {\bf z}_i\\ &= \sum_{i} \left( \int_{\mathcal Z} q({\bf z}_i) \log p({\bf x}_i \mid \theta ) d{\bf z}_i - \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)} d {\bf z}_i \right)\\ &= \sum_{i} \int_{\mathcal Z} \left( q({\bf z}_i) \log p({\bf x}_i \mid \theta ) - q({\bf z}_i) \log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)} \right) d {\bf z}_i \\ &= \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i \mid \theta )\, q({\bf z}_i) }{p({\bf x}_i, {\bf z}_i \mid \theta )} d {\bf z}_i \\ &= \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i \mid \theta )\, q({\bf z}_i) }{p({\bf z}_i \mid {\bf x}_i, \theta )\, p({\bf x}_i\mid \theta)} d {\bf z}_i \\ &= \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{ q({\bf z}_i) }{p({\bf z}_i \mid {\bf x}_i, \theta )} d {\bf z}_i \\ &= \sum_{i} \mathcal D_{KL}\left( q({\bf z}_i) \mid \mid p({\bf z}_i \mid {\bf x}_i, \theta )\right ) \end{align} $$

with the Kullback-Leibler divergence $\mathcal D_{KL}$.

We want to minimize the gap (which is equivalent to maximizing the lower bound w.r.t. $q$).

The Kullback-Leibler divergence $\mathcal D_{KL}$ is zero iff $q({\bf z}_i) = p({\bf z}_i \mid {\bf x}_i, \theta )$.

So, in the E-Step we set the variational distribution $q({\bf z}_i)$ to the posterior $p({\bf z}_i \mid {\bf x}_i, \theta^{(t)} )$ (provided we can compute $p({\bf z}_i \mid {\bf x}_i, \theta^{(t)} )$).

Note: For each $i$ we have an (in general different) $q({\bf z}_i)$, i.e. we use the usual "abuse of notation": the variable ${\bf z}_i$ of the function $q({\bf z}_i)$ identifies the distribution.

Note: If ${\bf z}_i$ can take only a small number of possible values, we can compute $p({\bf z}_i \mid {\bf x}_i, \theta^{(t)} )$ using Bayes' rule:

$$ p({\bf z}_i \mid {\bf x}_i, \theta )=\frac{p({\bf x}_i \mid {\bf z}_i, \theta)\, p({\bf z}_i \mid \theta )}{\sum_{{\bf z}'_i}p({\bf x}_i \mid {\bf z}'_i, \theta)\, p({\bf z}'_i \mid \theta)} $$
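For the illustrative toy mixture, the E-Step is exactly this Bayes-rule computation over the $K$ component labels; a sketch (the function and variable names are my own, not from the text):

```python
# Sketch of the E-Step for a discrete z_i (mixture components):
# q(z_i = k) = p(z_i = k | x_i, theta)  proportional to  p(x_i | z_i = k, theta) p(z_i = k | theta).
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

def e_step(x, pi, mu, sigma):
    """Return q with q[i, k] = p(z_i = k | x_i, theta), the posterior 'responsibilities'."""
    log_joint = np.log(pi) + norm.logpdf(x[:, None], mu, sigma)          # log p(x_i, z_i = k | theta)
    log_post = log_joint - logsumexp(log_joint, axis=1, keepdims=True)   # normalize over k (Bayes' rule)
    return np.exp(log_post)

x = np.array([-1.2, 0.3, 2.5, 2.9, -0.7])
q = e_step(x, np.array([0.4, 0.6]), np.array([-1.0, 3.0]), np.array([1.0, 0.5]))
print(q.sum(axis=1))   # each row sums to 1
```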

M-Step in detail

$$ \begin{align} \theta^{(t+1)} &= \text{arg} \max_{\theta} \mathcal L (q, \theta) \\ &=\text{arg} \max_{\theta} \left( \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log \frac{p({\bf x}_i, {\bf z}_i \mid \theta )}{q({\bf z}_i)} d {\bf z}_i \right)\\ &=\text{arg} \max_{\theta} \left(\sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log {p({\bf x}_i, {\bf z}_i \mid \theta )} d {\bf z}_i - \sum_{i} \int_{\mathcal Z} {q({\bf z}_i)} \log q({\bf z}_i) d {\bf z}_i \right)\\ &=\text{arg} \max_{\theta} \left( \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log {p({\bf x}_i, {\bf z}_i \mid \theta )} d {\bf z}_i\right)\\ &=\text{arg} \max_{\theta} \sum_{i} \mathbb E_{q_i} \left[ \log {p({\bf x}_i, {\bf z}_i \mid \theta )} \right] \end{align} $$

Note: The term $- \sum_{i} \int_{\mathcal Z} q({\bf z}_i) \log q({\bf z}_i) d {\bf z}_i$ could be dropped because it does not depend on $\theta$. So we have to maximize the expected complete-data log-likelihood (expectation w.r.t. $q({\bf z}_i)$), i.e. the complete-data log-likelihood averaged under our current guess of the distribution, $q^{(t+1)}({\bf z}_i) = p({\bf z}_i \mid {\bf x}_i, \theta^{(t)} )$ (see E-Step).
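For the illustrative 1-D Gaussian mixture, maximizing $\sum_{i} \mathbb E_{q_i} \left[ \log {p({\bf x}_i, {\bf z}_i \mid \theta )} \right]$ has a closed-form solution; the weighted updates below are a sketch of that special case, not of the general M-Step:

```python
# Sketch of the M-Step for the toy 1-D Gaussian mixture: closed-form maximization of
# the expected complete-data log-likelihood given the responsibilities q from the E-Step.
import numpy as np

def m_step(x, q):
    """x: observations, shape (n,); q[i, k] = q^{(t+1)}(z_i = k); returns (pi, mu, sigma)."""
    n_k = q.sum(axis=0)                                 # effective number of points per component
    pi = n_k / len(x)                                   # new mixing weights
    mu = (q * x[:, None]).sum(axis=0) / n_k             # weighted means
    var = (q * (x[:, None] - mu) ** 2).sum(axis=0) / n_k
    return pi, mu, np.sqrt(var)

# usage with hand-picked responsibilities (in practice q comes from the E-Step)
x = np.array([-1.2, 0.3, 2.5, 2.9, -0.7])
q = np.array([[0.9, 0.1], [0.7, 0.3], [0.1, 0.9], [0.05, 0.95], [0.8, 0.2]])
print(m_step(x, q))
```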

Convergence of EM

Lower bound:

$$ \log p({\bf X} \mid \theta^{(t+1)}) \geq \mathcal L(q^{(t+1)}, \theta^{(t+1)}) $$

and because of the optimization procedure (E-Step and M-Step) the following inequalities hold:

$$ \mathcal L(q^{(t+1)}, \theta^{(t+1)})\geq \mathcal L(q^{(t+1)}, \theta^{(t)}) \geq \mathcal L(q^{(t)}, \theta^{(t)}) $$

Moreover, since the E-Step sets $q^{(t+1)}({\bf z}_i) = p({\bf z}_i \mid {\bf x}_i, \theta^{(t)} )$, the gap vanishes there: $$ \mathcal L(q^{(t+1)}, \theta^{(t)}) = \log p({\bf X} \mid \theta^{(t)}) $$

Therefore $$ \log p({\bf X} \mid \theta^{(t+1)}) \geq \log p({\bf X} \mid \theta^{(t)}) $$

So it is guaranteed that the log-likelihood does not decrease in any iteration (until convergence).

  • This can be used for debugging: the log-likelihood must never decrease during the iterations (see the end-to-end sketch after this list).
  • EM converges to a (local) maximum (or a saddle point).
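An end-to-end sketch that puts both steps together for the illustrative two-component 1-D mixture and uses the monotonicity property as a sanity check (synthetic data and initial values are my own choices):

```python
# Sketch: EM for a two-component 1-D Gaussian mixture with a monotonicity check
# on the log-likelihood (the debugging property mentioned above).
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-1.0, 1.0, 200), rng.normal(3.0, 0.5, 300)])   # synthetic data

pi, mu, sigma = np.array([0.5, 0.5]), np.array([0.0, 1.0]), np.array([1.0, 1.0])
prev_ll = -np.inf
for t in range(200):
    lj = np.log(pi) + norm.logpdf(x[:, None], mu, sigma)      # log p(x_i, z_i = k | theta^(t))
    ll = logsumexp(lj, axis=1).sum()                          # current log-likelihood
    assert ll >= prev_ll - 1e-8, "log-likelihood decreased: bug in E- or M-Step"
    if ll - prev_ll < 1e-9:
        break                                                 # converged
    prev_ll = ll
    q = np.exp(lj - logsumexp(lj, axis=1, keepdims=True))     # E-Step: responsibilities
    n_k = q.sum(axis=0)                                       # M-Step: closed-form updates
    pi, mu = n_k / len(x), (q * x[:, None]).sum(axis=0) / n_k
    sigma = np.sqrt((q * (x[:, None] - mu) ** 2).sum(axis=0) / n_k)
print(pi, mu, sigma)
```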

Summary EM Algorithm

  • E-Step: $q^{(t+1)}= \text{arg}\min_q\sum_{i} \mathcal D_{KL}\left( q({\bf z}_i) \mid \mid p({\bf z}_i \mid {\bf x}_i, \theta^{(t)} )\right)$
  • M-Step: $\theta^{(t+1)} = \text{arg}\max_\theta \sum_{i} \mathbb E_{q_i^{(t+1)}} \left[ \log {p({\bf x}_i, {\bf z}_i \mid \theta )} \right] $
