Gaussian Mixtures and EM Algorithm

\(\newcommand{\x}{\boldsymbol{x}} \newcommand{\X}{\boldsymbol{X}} \newcommand{\z}{\boldsymbol{z}} \newcommand{\Z}{\boldsymbol{Z}} \newcommand{\mmu}{\boldsymbol{\mu}} \newcommand{\ssigma}{\boldsymbol{\Sigma}} \newcommand{\ttheta}{\boldsymbol{\theta}} \newcommand{\ppi}{\boldsymbol{\pi}}\) The Gaussian distribution is one of the most important and widely used distributions for data modeling. One limitation, though, is that it is unimodal, i.e., it has a single mode (or bump). This characteristic may prevent it from capturing multimodality, which is not uncommon in real data. A natural extension is to consider a convex combination of \(K\) Gaussian distributions, known as a \(K\)-component Gaussian mixture, whose PDF is given by

\[p(\x) = \sum_{k=1}^{K} \pi_k \mathcal{N}(\x | \mmu_k, \ssigma_k),\]

where the \(\pi_k \geq 0\) are called mixing coefficients and must sum to 1 for \(p(\x)\) to be a valid PDF. The parameters of this model are \(\ttheta := (\ppi, \mmu, \ssigma)\) where \(\ppi = (\pi_1, \ldots, \pi_K)\), \(\mmu = (\mmu_1, \ldots, \mmu_K)\) and \(\ssigma = (\ssigma_1, \ldots, \ssigma_K)\).
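The mixture density above can be evaluated directly from its definition. The following is a minimal NumPy sketch, not a reference implementation; the function names `gaussian_pdf` and `mixture_pdf` and the two-component parameter values are assumptions made only for illustration.

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    """Density of N(x | mu, sigma) at a single D-dimensional point x."""
    d = x.shape[0]
    diff = x - mu
    # Solve sigma @ y = diff instead of forming the inverse explicitly.
    maha = diff @ np.linalg.solve(sigma, diff)
    norm = np.sqrt((2.0 * np.pi) ** d * np.linalg.det(sigma))
    return np.exp(-0.5 * maha) / norm

def mixture_pdf(x, pis, mus, sigmas):
    """p(x) = sum_k pi_k N(x | mu_k, Sigma_k)."""
    return sum(pi * gaussian_pdf(x, mu, s) for pi, mu, s in zip(pis, mus, sigmas))

# Hypothetical 2-component mixture in 2-D (values chosen only for illustration).
pis = np.array([0.3, 0.7])
mus = np.array([[0.0, 0.0], [4.0, 4.0]])
sigmas = np.array([np.eye(2), np.eye(2)])

density_at_origin = mixture_pdf(np.array([0.0, 0.0]), pis, mus, sigmas)
```

At the origin the second component contributes almost nothing, so the density is dominated by the first term, \(\pi_1 \mathcal{N}(\boldsymbol{0} | \boldsymbol{0}, \boldsymbol{I})\).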

In a Gaussian mixture, each example \(\x\) is assumed to be generated from one of the components. For parameter inference, it is useful to introduce a \(K\)-dimensional binary variable \(\z\) associated with each \(\x\), with \(z_j=1\) if the \(j^{th}\) component is responsible for generating \(\x\). Note that there is exactly one \(1 \leq j \leq K\) such that \(z_j=1\); the remaining entries are 0. This representation of \(\z\) is known as 1-of-K encoding.

With the introduction of \(\z\) the model can be rewritten as

\[p(\x) = \sum_\z p(\z) p(\x | \z),\]

where, using the fact that 1-of-K encoding is used, we have

\[\begin{align*} p(\z) &= \prod_{j=1}^K \pi_j^{z_j} \\ p(\x | \z) &= \prod_{j=1}^K \mathcal{N}(\x | \mmu_j, \ssigma_j)^{z_j}. \end{align*}\]

The latter definition of \(p(\x)\) gives exactly the same model as before, only with the latent or hidden variable \(\z\) written explicitly. The variable \(\z\) is latent because in practice it is unobserved. If \(\z\) were observed, the parameter inference problem would reduce to component-wise inference. That is, the parameters of the \(k^{th}\) component would be given by the maximum likelihood estimate on the data generated by that component. Notice that the mixing coefficient \(\pi_k\) is in fact the prior probability that \(z_k = 1\), i.e., \(p(z_k = 1) = \pi_k\).
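The latent-variable form also describes how to generate data: first draw \(\z\) from its prior, then draw \(\x\) from the selected Gaussian. Below is a small NumPy sketch of this two-stage sampling; the function name `sample_mixture` and the parameter values are hypothetical and chosen only for illustration.

```python
import numpy as np

def sample_mixture(n, pis, mus, sigmas, rng):
    """Draw n examples by first sampling z, then x given z."""
    k = len(pis)
    # Sample the index j with z_j = 1 according to p(z_j = 1) = pi_j.
    labels = rng.choice(k, size=n, p=pis)
    # 1-of-K encoding: row i has a single 1 in column labels[i].
    Z = np.eye(k, dtype=int)[labels]
    # Sample x_i from the Gaussian selected by z_i.
    X = np.stack([rng.multivariate_normal(mus[j], sigmas[j]) for j in labels])
    return X, Z

# Hypothetical 2-component mixture in 2-D.
rng = np.random.default_rng(0)
pis = np.array([0.3, 0.7])
mus = np.array([[0.0, 0.0], [4.0, 4.0]])
sigmas = np.array([np.eye(2), np.eye(2)])

X, Z = sample_mixture(1000, pis, mus, sigmas, rng)
```

In this synthetic setting `Z` is available, so each component could be fitted separately on the rows of `X` it generated. With real data only `X` is observed.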

Maximum Likelihood Estimate for Gaussian Mixtures

The fact that \(\z\) is unobserved prevents us from estimating parameters in a component-wise manner as we do not know the generating component of any example. So, it makes sense to instead find the parameters \(\ttheta\) such that the incomplete-data log likelihood is maximized.

\[\mathcal{L} = \ln p(\X | \ppi, \mmu, \ssigma) = \sum_{i=1}^n \ln \sum_{j=1}^K \pi_j \mathcal{N}(\x_i | \mmu_j, \ssigma_j).\]

The term "incomplete" arises from the fact that we do not consider the complete-data log likelihood \(\ln p(\X, \Z | \ppi, \mmu, \ssigma)\), since \(\Z\) is unobserved. To find the optimal \(\ttheta\), a straightforward gradient-ascent algorithm can be considered. However, as it turns out, setting the derivatives of \(\mathcal{L}\) with respect to \((\ppi, \mmu, \ssigma)\) to zero does not yield closed-form solutions, which makes direct maximization challenging.
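For concreteness, the incomplete-data log likelihood can be computed as in the sketch below. It works in log space and uses the log-sum-exp trick for the inner sum, which is a numerical choice of this sketch rather than something the derivation requires. The function names and the toy data are assumptions for illustration.

```python
import numpy as np

def log_gaussian(X, mu, sigma):
    """Row-wise ln N(x_i | mu, sigma) for X of shape (n, D)."""
    d = X.shape[1]
    chol = np.linalg.cholesky(sigma)
    # Solve L y = (x - mu); then y . y is the squared Mahalanobis distance.
    y = np.linalg.solve(chol, (X - mu).T)
    maha = np.sum(y ** 2, axis=0)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)

def log_likelihood(X, pis, mus, sigmas):
    """sum_i ln sum_j pi_j N(x_i | mu_j, Sigma_j)."""
    # log_terms[i, j] = ln pi_j + ln N(x_i | mu_j, Sigma_j)
    log_terms = np.stack(
        [np.log(pi) + log_gaussian(X, mu, s) for pi, mu, s in zip(pis, mus, sigmas)],
        axis=1,
    )
    # Log-sum-exp over components, shifted by the row maximum for stability.
    m = log_terms.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.sum(np.exp(log_terms - m), axis=1))))

# Hypothetical parameters and two toy data points, one at each component mean.
pis = np.array([0.3, 0.7])
mus = np.array([[0.0, 0.0], [4.0, 4.0]])
sigmas = np.array([np.eye(2), np.eye(2)])
X = np.array([[0.0, 0.0], [4.0, 4.0]])

ll = log_likelihood(X, pis, mus, sigmas)
```

Because the logarithm sits outside the sum over components, the parameters of different components do not decouple, which is why maximizing this quantity directly is awkward.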

Expectation Maximization

A general algorithm for finding a maximum likelihood solution for models with latent variables is the expectation maximization (EM) algorithm. The core idea of the EM algorithm is to pretend that we know \(\z\) for each example \(\x\). The information about \(\z\) can be estimated (E step) once \(\ttheta\) is known. Likewise, if \(\z\) is known, \(\ttheta\) can be estimated (M step). So, starting from an initial guess of \(\ttheta\), the EM algorithm alternates between the E step and the M step until either \(\ttheta\) or the log likelihood converges. As we shall see, the information about which component generated an example is encapsulated in the conditional probability of \(\z\) given \(\x\).

The derivation of EM starts by computing \(\frac{\partial \mathcal{L}}{ \partial \mmu_k}\), \(\frac{\partial \mathcal{L}}{\partial \ssigma_k}\), \(\frac{\partial \mathcal{L}}{\partial \pi_k}\) and equating them to \(\boldsymbol{0}\). A Lagrange multiplier is needed for \(\frac{\partial \mathcal{L}}{\partial \pi_k}\) since the constraint \(\sum_j \pi_j =1\) must be satisfied. Interestingly, it turns out that in all three cases the conditional probability of \(\z\) given \(\x\) shows up in the update equations:

\[\begin{align*} \gamma(z_k | \ttheta) &= p(z_k=1|\x) \\ &= \frac{p(z_k=1)\, p(\x | z_k=1)}{p(\x)} \\ &= \frac{\pi_k \mathcal{N}(\x | \mmu_k, \ssigma_k)}{ \sum_{j=1}^K \pi_j \mathcal{N}(\x | \mmu_j, \ssigma_j) }. \end{align*}\]

This conditional probability can be interpreted as the responsibility that the \(k^{th}\) component takes for generating \(\x\). In fact, this conditional probability is the quantity computed in the E step, i.e., we compute \(\gamma(z_k| \ttheta^{old})\), where \(\ttheta^{old}\) denotes the parameters obtained in the previous M step. Writing \(\gamma(z_{ik})\) for the responsibility of the \(k^{th}\) component for example \(\x_i\), the complete update equations for \(\ppi, \mmu\) and \(\ssigma\) are as follows. These update equations form the M step:

\[\begin{align*} \mmu_k^{new} &= \frac{1}{n_k} \sum_{i=1}^n \gamma(z_{ik}) \x_i, \\ \ssigma_k^{new} &= \frac{1}{n_k}\sum_{i=1}^n \gamma(z_{ik}) (\x_i-\mmu_k^{new})(\x_i-\mmu_k^{new})^\top, \\ \pi_k^{new} &= \frac{n_k}{n}, \end{align*}\]

where \(n_k = \sum_{i=1}^n \gamma(z_{ik}).\)
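As a sketch of where the first of these updates comes from, assuming \(\ssigma_k\) is nonsingular, differentiating \(\mathcal{L}\) with respect to \(\mmu_k\) and setting the result to zero gives

\[\begin{align*} \frac{\partial \mathcal{L}}{\partial \mmu_k} &= \sum_{i=1}^n \frac{\pi_k \mathcal{N}(\x_i | \mmu_k, \ssigma_k)}{\sum_{j=1}^K \pi_j \mathcal{N}(\x_i | \mmu_j, \ssigma_j)} \ssigma_k^{-1} (\x_i - \mmu_k) \\ &= \ssigma_k^{-1} \sum_{i=1}^n \gamma(z_{ik}) (\x_i - \mmu_k) = \boldsymbol{0}. \end{align*}\]

Multiplying by \(\ssigma_k\) and rearranging yields \(\sum_{i=1}^n \gamma(z_{ik}) \x_i = n_k \mmu_k\), which is the update for \(\mmu_k\) above. The updates for \(\ssigma_k\) and \(\pi_k\) follow from analogous calculations.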

It is very important to note that these equations do not constitute a closed-form solution, as \(\gamma\) itself depends on \(\ttheta\). Nonetheless, if we treat \(\gamma\) as if it were independent of \(\ttheta\), the update equations are very intuitive. For example, \(\mmu_k\) is updated with a weighted average of the sample, where each weight is the responsibility of the \(k^{th}\) component for the corresponding example. The value \(n_k\) can be interpreted as the effective number of examples for which the \(k^{th}\) component is responsible. Note that \(n = \sum_k n_k\).
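Putting the E step and M step together, the whole procedure might be sketched as follows. This is a minimal illustration under stated assumptions, not a definitive implementation: the function names, the initialization (uniform mixing coefficients, means drawn from the data, a shared data covariance), the stopping tolerance and the synthetic data are all choices made for this example and are not prescribed by the derivation.

```python
import numpy as np

def log_gaussian(X, mu, sigma):
    """Row-wise ln N(x_i | mu, sigma) for X of shape (n, D)."""
    d = X.shape[1]
    chol = np.linalg.cholesky(sigma)
    y = np.linalg.solve(chol, (X - mu).T)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + np.sum(y ** 2, axis=0))

def e_step(X, pis, mus, sigmas):
    """Responsibilities gamma[i, k] and the log likelihood under the current theta."""
    log_terms = np.stack(
        [np.log(pi) + log_gaussian(X, mu, s) for pi, mu, s in zip(pis, mus, sigmas)],
        axis=1,
    )
    m = log_terms.max(axis=1, keepdims=True)
    log_px = m + np.log(np.sum(np.exp(log_terms - m), axis=1, keepdims=True))
    gamma = np.exp(log_terms - log_px)  # pi_k N_k / sum_j pi_j N_j
    return gamma, float(log_px.sum())

def m_step(X, gamma):
    """Update (pi, mu, Sigma) with the responsibilities held fixed."""
    n, d = X.shape
    n_k = gamma.sum(axis=0)  # effective sample sizes
    pis = n_k / n
    mus = (gamma.T @ X) / n_k[:, None]
    sigmas = np.empty((len(n_k), d, d))
    for k in range(len(n_k)):
        diff = X - mus[k]
        sigmas[k] = (gamma[:, k, None] * diff).T @ diff / n_k[k]
    return pis, mus, sigmas

def fit_gmm(X, k, n_iter=200, tol=1e-8, seed=0):
    """Alternate E and M steps until the log likelihood stops changing."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    # Assumed initialization; other choices are possible.
    pis = np.full(k, 1.0 / k)
    mus = X[rng.choice(n, size=k, replace=False)]
    sigmas = np.stack([np.atleast_2d(np.cov(X.T))] * k)
    history = []
    for _ in range(n_iter):
        gamma, ll = e_step(X, pis, mus, sigmas)
        history.append(ll)
        if len(history) > 1 and abs(history[-1] - history[-2]) < tol:
            break
        pis, mus, sigmas = m_step(X, gamma)
    return pis, mus, sigmas, gamma, history

# Synthetic data from two hypothetical clusters, for illustration only.
rng = np.random.default_rng(1)
X = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(180, 2)),
    rng.normal(loc=4.0, scale=1.0, size=(420, 2)),
])
pis, mus, sigmas, gamma, history = fit_gmm(X, k=2)
```

The list `history` records the log likelihood at each E step, and it should never decrease from one iteration to the next. This sketch adds no safeguard against a covariance matrix becoming singular, which can happen when a component collapses onto very few points.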

Closing Remarks

Reference: Section 9.2 of the PRML book by Chris Bishop.