import matplotlib.pyplot as plt
from random import random, sample
from numpy import argmin, mean, var
from math import pi, cos, sin

# Number of clusters requested from both the hand-written k_means and SciPy's kmeans.
K = 4

# Subplot grid: the first panel shows the input, every other panel one k-means run.
ROWS = 3
COLUMNS = 4

def random_point(x, y, radius):
    ''' Return one random point near (x, y).

    A direction theta is drawn uniformly from [0, 2*pi) and a distance as
    radius * u**2 with u uniform in [0, 1); squaring u biases the points
    towards (x, y).  Consumes exactly two values from random.random().
    '''
    theta = random() * 2 * pi           # first draw: direction
    offset = radius * random() ** 2     # second draw: distance from the centre
    dx = offset * cos(theta)
    dy = offset * sin(theta)
    return x + dx, y + dy

def random_points(n, x, y, radius):
    ''' Lazily yield n points, each produced by random_point(x, y, radius). '''
    yield from (random_point(x, y, radius) for _ in range(n))

def dist(p, q):
    ''' Return the squared Euclidean distance between points p and q.

    Coordinates are paired with zip, so any extra coordinates of the longer
    point are ignored.  No square root is taken.
    '''
    squared_gaps = ((u - v) ** 2 for u, v in zip(p, q))
    return sum(squared_gaps)

def k_means(points, k):
    ''' Cluster `points` into `k` groups with Lloyd's algorithm.

    The initial centroids are k elements of `points` drawn without
    replacement (by position) via random.sample.  Each iteration assigns
    every point to its nearest centroid by squared distance (dist; ties go to
    the lowest index, as numpy.argmin returns the first minimum), then moves
    each centroid to the coordinate-wise mean of its cluster.

    Iteration stops when the recomputed centroids equal the previous ones,
    or when some cluster is empty (a message is printed in that case).

    Returns (clusters, centroids):
      clusters  -- list of k lists; clusters[i] holds the points that were
                   assigned to centroids[-1][i].
      centroids -- history of centroid lists, starting with the initial
                   sample; every entry is a full list of k coordinate tuples.

    random.sample raises ValueError if k is negative or exceeds len(points).
    '''
    # Use the caller's k; the previous code sampled the global K instead.
    centroid = sample(points, k)
    centroids = [centroid]

    while True:
        clusters = [[] for _ in centroid]
        for p in points:
            i = argmin([dist(p, c) for c in centroid])
            clusters[i].append(p)

        # Check for empty clusters *before* averaging.  An empty cluster
        # would average to the empty tuple (), leaving a malformed entry in
        # centroids[-1] that no longer matches the centroids `clusters` were
        # assigned against.
        if any(not c for c in clusters):
            print("Not good - empty cluster")
            break

        centroid = [tuple(map(mean, zip(*c))) for c in clusters]

        if centroid == centroids[-1]:
            break  # converged: assignments can no longer change

        centroids.append(centroid)

    return clusters, centroids
    
# Three synthetic blobs, each given as (count, centre x, centre y, radius).
points = []
for count, cx, cy, spread in [
    (100, 0, 0, 1),
    (100, 1, 2, 0.5),
    (100, -1, 1, 0.25),
]:
    points.extend(random_points(count, cx, cy, spread))
    
# First panel: the raw input data.
plt.subplot(ROWS, COLUMNS, 1)
plt.plot(*zip(*points), 'r.')
plt.title("%s input points" % len(points))

# Every remaining panel shows an independent k-means run; the runs differ
# only in their randomly sampled initial centroids.
for plot in range(2, ROWS * COLUMNS + 1):
    plt.subplot(ROWS, COLUMNS, plot)

    clusters, centroids = k_means(points, K)

    # Total squared distance from each point to its cluster's final centroid.
    distortion = sum(dist(p, q)
                     for p, c in zip(centroids[-1], clusters)
                     for q in c)

    plt.title("%s iterations\n%.1f distortion" % (len(centroids), distortion))

    # k_means may stop with an empty cluster, and a centroid averaged from an
    # empty cluster is the empty tuple ().  Unpacking either through
    # zip(*...) yields no coordinates, which would leave the format string as
    # plot()'s only argument -- so skip them.
    for c in clusters:
        if c:
            plt.plot(*zip(*c), '.')

    for path in zip(*centroids):
        visited = [pos for pos in path if pos]
        plt.plot(*zip(*visited), 'k.-')   # centroid trajectory
    plt.plot(*zip(*centroids[0]), 'ks')   # initial centroids
    final = [pos for pos in centroids[-1] if pos]
    plt.plot(*zip(*final), 'ko')          # final centroids

plt.tight_layout()
plt.show()

from scipy.cluster.vq import vq, kmeans, whiten

# Cross-check against SciPy: whiten the data (per-feature rescaling), then
# plot it together with the codebook returned by scipy's kmeans.
points = whiten(points)
plt.plot(*points.T, 'r.')

codebook = kmeans(points, K)[0]
plt.plot(*codebook.T, "bo")
plt.title("scipy.cluster.vq.kmeans")

plt.show()
