Gaussian Mixture Models Explained

Here I write about GMMs because it's a method of clustering that serves as a basis for more advanced concepts.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import seaborn as sns
import pandas as pd
from math import sqrt, log, exp, pi
from random import uniform
%matplotlib inline
In [2]:
#read our dataset
df = pd.read_csv("data/bimodal_example.csv")
data = df.x
data.head()
Out[2]:
0    0.252851
1   -1.034562
2    3.319558
3    4.552363
4   -0.775995
Name: x, dtype: float64
Here, the problem is unsupervised, and it is about clustering.
A universally used generative unsupervised clustering method is Gaussian Mixture Model (GMM), which is also known as "EM Clustering."
These are 2 types of clustering:
  • Hard Clustering: for each data point assign a cluster
  • Soft Clustering: for each data point assign a probability distribution over clusters

Fitting a single gaussian

In a one-dimensional case, there's this probability distribution that might represent the whole dataset.
$$p(x | \theta) = \mathcal N(x \ | \ \mu, \sigma^2)$$
In [3]:
gfit_mean = np.mean(data)
gfit_sigma = np.std(data)

x = np.linspace(-6, 8, 200)
g_single = stats.norm(gfit_mean, gfit_sigma).pdf(x)
sns.distplot(data, bins=20, kde=False, norm_hist=True)
plt.plot(x, g_single, label='single gaussian')
plt.legend();
Two Gaussians would represent the data better.

Another example

In a multidimensional case, this would be the probability distribution.
$$p(x | \theta) = \mathcal N(x \ | \ \mu, \Sigma^2)$$
In [4]:
sample_gauss_1 = np.random.multivariate_normal([0, 0], [[1, 0], [0, 100]], 100)
sample_gauss_3 = np.random.multivariate_normal([10, 10], [[10, 0], [0, 10]], 100)

data_2d = np.vstack((sample_gauss_1, sample_gauss_3))
plt.scatter(data_2d[:,0], data_2d[:,1])
Out[4]:
<matplotlib.collections.PathCollection at 0x109752f28>
In [5]:
gfit_mean = (np.mean(data_2d[:,0]), np.mean(data_2d[:,1]))
gfit_sigma = np.cov(data_2d.T)

ngrid = 10
xlim = (-15, 25)
ylim = (-15, 25)
x = np.linspace(xlim[0], xlim[1], ngrid, endpoint=True)
y = np.linspace(ylim[0], ylim[1], ngrid, endpoint=True)
X, Y = np.meshgrid(x, y)
xy = np.vstack([X.ravel(), Y.ravel()]).T

z = np.array([stats.multivariate_normal.pdf(xy_elem, gfit_mean, gfit_sigma) for xy_elem in xy])
z = z.reshape((ngrid, ngrid))
plt.scatter(data_2d[:,0], data_2d[:,1])
plt.contour(x, y, z, cmap=plt.cm.jet)
Out[5]:
<matplotlib.contour.QuadContourSet at 0x10977a518>
But there is a more approximate Gaussian that can be used to generate the data: a mixture of Gaussians. In the last example, 2 Gaussians would do the trick.

Notation

We'll use the one-dimensional example, so this will be the dataset
$$ D = \{(x^{(i)})\}_{i=1}^{N} $$

Defining a model

We need two Normal distributions $\mathcal N(\mu_1, \sigma_1^2)$ and $\mathcal N(\mu_2, \sigma_2^2)$.
Hence, there are 5 parameters: 4 of them are from Normal distributions, and 1 more for the probability of choosing one of them.
Let $\psi$ be the probability that the data comes from the first Normal, the parameter in this model is $\theta = (\psi, \mu_1, \sigma_1^2, \mu_2, \sigma_2^2)$.
This is the probability of the occurrence of a data point
$$ p(x | \theta) = \psi \mathcal N(x | \mu_1, \sigma_1^2) + (1 - \psi) \mathcal N (x | \mu_2, \sigma_2^2) $$

Optimization problem

Max. likelihood
$$ \begin{align} \max_\theta &\prod_{i=1}^N p(x_i|\theta) \\ \max_\theta &\prod_{i=1}^N \psi \mathcal N(x_i | \mu_1, \sigma_1^2) + (1 - \psi) \mathcal N (x_i | \mu_2, \sigma_2^2) \end{align} $$

Introducing a latent variable (LV)

Each data point was generated by using some LV "z" that will take 2 values: 1 and 2.
Expectation Maximization (EM) can be used to estimate $\theta$

EM Algorithm for GMM

Continuing with the mixture of normals model as our example, we can apply the EM algorithm to estimate $\theta = \{\mu, \sigma, \psi\}$.
This was copied from https://github.com/fonnesbeck/Bios8366/ Initiazlize all parameters $\theta_0 = \{\mu_0, \sigma_0, \psi_0\}$
Repeat until convergence:
  • E-step: guess the values of $\{z_i\}$
Compute probabilities of group membership: $w_{ij} = P(z_i = j | x_i, \theta)$ for each group $j=1,\ldots,k$. This is done via Bayes' formula:
$$P(z_i = j | x_i) = \frac{P(x_i | z_i=j) P(z_i=j)}{\sum_{l=1}^k P(x_i | z_i=l) P(z_i=l)}$$
$P(z_i=j|\theta) = \psi_j$
$P(x|z_i=j, \theta) = \mathcal N (x | \mu_j, \sigma_j)$
$\theta$ has been dropped for notational convenience.
  • M-step: update estimates of parameters $\theta$
$$\begin{aligned}\psi_j &= \frac{1}{m} \sum_i w_{ij} \\ \mu_j &= \frac{\sum_i w_{ij} x_i}{\sum_i w_{ij}} \\ \sigma_j &= \frac{\sum_i w_{ij}(x_i - \mu_j)^2}{\sum_i w_{ij}} \end{aligned}$$
In [6]:
class Gaussian:
    "Model univariate Gaussian"
    
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def pdf(self, datum):
        u = (datum - self.mu) / abs(self.sigma)
        y = (1 / (sqrt(2 * pi) * abs(self.sigma))) * exp(-u * u / 2)
        return y
    
    def __repr__(self):
        return 'Gaussian({0:.03}, {1:.03})'.format(self.mu, self.sigma)
    
class GaussianMixture:
    "Model mixture of two univariate Gaussians and their EM estimation"

    def __init__(self, data, mu1_ini, sigma1_ini, mu2_ini, sigma2_ini, psi=.5):
        self.data = data
        self.one = Gaussian(mu1_ini, sigma1_ini)
        self.two = Gaussian(mu2_ini, sigma2_ini)
        self.psi = psi
        
    def Estep(self):
        self.loglike = 0.
        for x in self.data:
            wp1 = self.one.pdf(x) * self.psi
            wp2 = self.two.pdf(x) * (1. - self.psi)
            den = wp1 + wp2
            # normalize
            wp1 /= den
            wp2 /= den
            # add into loglike
            self.loglike += log(wp1 + wp2)
            yield(wp1, wp2)

    def Mstep(self, weights):
        # compute denominators
        (left, right) = zip(*weights)
        one_den = sum(left)
        two_den = sum(right)
        # compute new means
        self.one.mu = sum(w * d / one_den for (w, d) in zip(left, data))
        self.two.mu = sum(w * d / two_den for (w, d) in zip(right, data))
        # compute new sigmas
        self.one.sigma = sqrt(sum(w * ((d - self.one.mu) ** 2)
                                  for (w, d) in zip(left, data)) / one_den)
        self.two.sigma = sqrt(sum(w * ((d - self.two.mu) ** 2)
                                  for (w, d) in zip(right, data)) / two_den)
        # compute new psi
        self.psi = one_den / len(data)
        
    def fit(self, n_iterations, verbose=True):
        best_loglike = float('-inf')
        
        for i in range(n_iterations):
            if verbose:
                print('{0} loglike: {1}, {2}'.format(i, best_loglike, self))

            self.Mstep(self.Estep())

            if self.loglike > best_loglike:
                best_loglike = self.loglike

    def pdf(self, x):
        return (self.psi)*self.one.pdf(x) + (1-self.psi)*self.two.pdf(x)

    def __str__(self):
        return 'GaussianMixture: {0}, {1}, psi={2:.03})'.format(self.one, 
                                                        self.two, 
                                                        self.psi)
In [7]:
n_iterations = 100
GMM = GaussianMixture(data, -5., 10., 5., 10.)
GMM.fit(n_iterations)
0 loglike: -inf, GaussianMixture: Gaussian(-5.0, 10.0), Gaussian(5.0, 10.0), psi=0.5)
1 loglike: -1.1102230246251565e-15, GaussianMixture: Gaussian(1.69, 1.8), Gaussian(2.0, 1.74), psi=0.454)
2 loglike: -1.1102230246251565e-15, GaussianMixture: Gaussian(1.68, 1.82), Gaussian(2.02, 1.72), psi=0.454)
3 loglike: -1.1102230246251565e-15, GaussianMixture: Gaussian(1.66, 1.83), Gaussian(2.03, 1.7), psi=0.454)
4 loglike: -9.992007221626409e-16, GaussianMixture: Gaussian(1.63, 1.84), Gaussian(2.06, 1.69), psi=0.454)
5 loglike: -9.992007221626409e-16, GaussianMixture: Gaussian(1.6, 1.85), Gaussian(2.08, 1.67), psi=0.454)
6 loglike: -9.992007221626409e-16, GaussianMixture: Gaussian(1.56, 1.85), Gaussian(2.11, 1.66), psi=0.455)
7 loglike: -6.661338147750939e-16, GaussianMixture: Gaussian(1.52, 1.86), Gaussian(2.15, 1.64), psi=0.455)
8 loglike: -6.661338147750939e-16, GaussianMixture: Gaussian(1.48, 1.86), Gaussian(2.19, 1.62), psi=0.455)
9 loglike: -6.661338147750939e-16, GaussianMixture: Gaussian(1.43, 1.87), Gaussian(2.23, 1.6), psi=0.456)
10 loglike: -6.661338147750939e-16, GaussianMixture: Gaussian(1.37, 1.87), Gaussian(2.28, 1.57), psi=0.457)
11 loglike: -6.661338147750939e-16, GaussianMixture: Gaussian(1.31, 1.87), Gaussian(2.33, 1.54), psi=0.458)
12 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(1.25, 1.86), Gaussian(2.39, 1.5), psi=0.46)
13 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(1.18, 1.85), Gaussian(2.45, 1.46), psi=0.462)
14 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(1.1, 1.83), Gaussian(2.53, 1.41), psi=0.465)
15 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(1.02, 1.81), Gaussian(2.61, 1.35), psi=0.468)
16 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.927, 1.77), Gaussian(2.7, 1.28), psi=0.473)
17 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.833, 1.73), Gaussian(2.81, 1.19), psi=0.478)
18 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.735, 1.66), Gaussian(2.92, 1.09), psi=0.483)
19 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.64, 1.59), Gaussian(3.02, 0.979), psi=0.486)
20 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.554, 1.51), Gaussian(3.1, 0.894), psi=0.485)
21 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.482, 1.44), Gaussian(3.15, 0.843), psi=0.481)
22 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.42, 1.37), Gaussian(3.17, 0.82), psi=0.475)
23 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.365, 1.32), Gaussian(3.18, 0.81), psi=0.467)
24 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.314, 1.28), Gaussian(3.17, 0.808), psi=0.458)
25 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.265, 1.24), Gaussian(3.16, 0.809), psi=0.449)
26 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.218, 1.21), Gaussian(3.15, 0.812), psi=0.439)
27 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.173, 1.18), Gaussian(3.14, 0.816), psi=0.431)
28 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.131, 1.14), Gaussian(3.13, 0.82), psi=0.422)
29 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.0914, 1.11), Gaussian(3.12, 0.824), psi=0.415)
30 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.0558, 1.09), Gaussian(3.11, 0.829), psi=0.408)
31 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(0.0239, 1.06), Gaussian(3.1, 0.834), psi=0.401)
32 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.004, 1.04), Gaussian(3.09, 0.839), psi=0.396)
33 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.028, 1.02), Gaussian(3.08, 0.843), psi=0.391)
34 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.0483, 1.0), Gaussian(3.07, 0.848), psi=0.387)
35 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.0653, 0.99), Gaussian(3.06, 0.852), psi=0.383)
36 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.0792, 0.979), Gaussian(3.06, 0.855), psi=0.38)
37 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.0906, 0.97), Gaussian(3.05, 0.858), psi=0.378)
38 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.0998, 0.963), Gaussian(3.05, 0.861), psi=0.376)
39 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.107, 0.957), Gaussian(3.04, 0.863), psi=0.375)
40 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.113, 0.953), Gaussian(3.04, 0.865), psi=0.373)
41 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.118, 0.949), Gaussian(3.04, 0.867), psi=0.372)
42 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.122, 0.946), Gaussian(3.04, 0.868), psi=0.371)
43 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.125, 0.944), Gaussian(3.03, 0.869), psi=0.371)
44 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.127, 0.942), Gaussian(3.03, 0.87), psi=0.37)
45 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.129, 0.941), Gaussian(3.03, 0.871), psi=0.37)
46 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.13, 0.94), Gaussian(3.03, 0.871), psi=0.37)
47 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.132, 0.939), Gaussian(3.03, 0.872), psi=0.369)
48 loglike: -4.440892098500626e-16, GaussianMixture: Gaussian(-0.133, 0.938), Gaussian(3.03, 0.872), psi=0.369)
49 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.133, 0.938), Gaussian(3.03, 0.872), psi=0.369)
50 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.134, 0.937), Gaussian(3.03, 0.873), psi=0.369)
51 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.134, 0.937), Gaussian(3.03, 0.873), psi=0.369)
52 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.135, 0.937), Gaussian(3.03, 0.873), psi=0.369)
53 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.135, 0.936), Gaussian(3.03, 0.873), psi=0.368)
54 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.135, 0.936), Gaussian(3.03, 0.873), psi=0.368)
55 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
56 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
57 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
58 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
59 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
60 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
61 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.873), psi=0.368)
62 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
63 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
64 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
65 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
66 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
67 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
68 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
69 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
70 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
71 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
72 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
73 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
74 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
75 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
76 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
77 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
78 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
79 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
80 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
81 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
82 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
83 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
84 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
85 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
86 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
87 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
88 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
89 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
90 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
91 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
92 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
93 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
94 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
95 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
96 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
97 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
98 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
99 loglike: -1.1102230246251578e-16, GaussianMixture: Gaussian(-0.136, 0.936), Gaussian(3.03, 0.874), psi=0.368)
In [8]:
x = np.linspace(-6, 8, 200)
sns.distplot(data, bins=20, kde=False, norm_hist=True)
g_both = [GMM.pdf(e) for e in x]
plt.plot(x, g_both, label='gaussian mixture');
plt.legend();
This can be seen as 2 clusters.
After the posterior of a latent variable $z$ was computed $P(z_i = j | x_i, \theta)$, a better fit was found.