klips/python/k-means/k-means.py

439 lines
17 KiB
Python

################################################################################
# Author: Shaun Reed #
# About: K-Means clustering CLI #
# Contact: shaunrd0@gmail.com | URL: www.shaunreed.com | GitHub: shaunrd0 #
################################################################################
from ast import literal_eval
from itertools import chain
from matplotlib import pyplot as plt
from typing import List
import argparse
import math
import numpy as np
import random
import sys
################################################################################
# CLI Argument Parser
################################################################################
# ==============================================================================
def init_parser():
parser = argparse.ArgumentParser(
description='K-means clustering program for clustering data read from a file, terminal, or randomly generated',
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
'clusters', metavar='CLUSTER_COUNT', type=int, nargs='?',
help=
'''Total number of desired clusters
(default: '%(default)s')
''',
default=2
)
parser.add_argument(
'shift', metavar='CENTROID_SHIFT', type=float, nargs='?',
help=
'''Centroid shift threshold. If cluster centroids move less-than this value, clustering is finished
(default: '%(default)s')
''',
default=1.0
)
parser.add_argument(
'loops', metavar='LOOP_COUNT', type=int, nargs='?',
help=
'''Maximum count of loops to perform clustering
(default: '%(default)s')
''',
default=3
)
parser.add_argument(
'--data', '-d', metavar='X,Y', type=point, nargs='*',
help=
'''A list of data points separated by spaces as: x,y x,y x,y ...
(default: '%(default)s')
''',
default=[(1.0, 2.0), (2.0, 3.0), (2.0, 2.0), (5.0, 6.0), (6.0, 7.0), (6.0, 8.0), (7.0, 11.0), (1.0, 1.0)]
)
parser.add_argument(
'--seeds', '--seed', '-s', metavar='X,Y', type=point, nargs='*',
help=
'''A list of seed points separated by spaces as: x,y x,y x,y ...
Number of seeds provided must match CLUSTER_COUNT, or else CLUSTER_COUNT will be overriden.
''',
)
parser.add_argument(
'--silent', action='store_true',
help=
'''When this flag is set, scatter plot visualizations will not be shown
(default: '%(default)s')
''',
default=False
)
parser.add_argument(
'--verbose', '-v', action='store_true',
help=
'''When this flag is set, cluster members will be shown in output
(default: '%(default)s')
''',
default=False
)
parser.add_argument(
'--random', '-r', action='store_true',
help=
'''When this flag is set, data will be randomly generated
(default: '%(default)s')
''',
default=False
)
parser.add_argument(
'--radius', metavar='RADIUS', type=float, nargs='?',
help=
'''Initial radius to use for clusters
(default: '%(default)s')
''',
default=None
)
parser.add_argument(
'--lock-radius', '-l', action='store_true',
help=
'''When this flag is set, centroid radius will not be recalculated
(default: '%(default)s')
''',
default=False
)
parser.add_argument(
'--file', '-f', metavar='FILE_PATH', nargs='?', type=open,
help=
'''Optionally provide file for data to be read from. Each point must be on it\'s own line with format x,y
''',
)
return parser
################################################################################
# Helper Functions
################################################################################
# ==============================================================================
def point(arg):
"""
Helper function for parsing x,y points provided through argparse CLI
:param arg: A single argument passed to an option or positional argument
:return: A tuple (x, y) representing a data point
"""
try:
x, y = literal_eval(arg)
return float(x), float(y) # Cast all point values to float
except:
raise argparse.ArgumentTypeError("Please provide data points in x,y format")
def random_data():
"""
Generates random data points for testing clustering
:return: A list of random data point tuples [(1, 1), (2, 4), ...]
"""
data_size = random.randint(50, random.randint(100, 200))
data = []
for x in range(0, data_size):
data.append((random.randint(0, 100), random.randint(0, 100)))
return data
def round_points(points, precision=4):
"""
Rounds all points in a list to a given decimal place
:param points: A list of data points to round to requested decimal place
:param precision: The decimal place to round to
:return: A list of points where (x, y) has been rounded to match requested precision value
"""
points = [(round(x, precision), round(y, precision)) for x,y in points]
return points
################################################################################
# K-means Clustering
################################################################################
# ==============================================================================
def select_seeds(data):
"""
Randomly select N seeds where N is the number of clusters requested through the CLI
:param data: A list of data points [(0, 1), (2, 2), (1, 4), ...]
:return: Dictionary of {seeds: radius}; For example {(2, 2): 5.0, (1, 4): 5.0}
"""
assert(len(data) > context.clusters)
x, y = zip(*data)
seeds = {} # Store seeds in a dictionary<seed, radius>
for i in range(0, context.clusters):
while True:
new_seed = data[random.randint(0, len(data) - 1)]
if new_seed not in seeds:
break
seeds[new_seed] = i if not context.radius else context.radius
if context.radius:
# An initial radius was provided and applied. Use it.
return seeds
else:
# No initial radius was provided, so calculate one
return update_clusters(seeds)
def points_average(data):
"""
Finds average (x, y) for points in data list [(x, y), (x, y), ...]
Used for updating cluster centroid positions
:param data: List [(x, y), (x, y), ...]
:return: An average (x, y) position for the list of points
"""
x, y = 0, 0
for pair in data:
x += pair[0]
y += pair[1]
x = float(x / len(data))
y = float(y / len(data))
return x, y
def update_clusters(seeds, clusters=None):
"""
Seeds {(x, y), radius} for clusters must be provided
If no clusters {(x, y), [members, ...]} are provided, initialize cluster radius given seeds
If clusters are provided, update centroids and radius
:param seeds: Dictionary of {cluster_seed: radius}; Example {(x, y), radius, (x, y): radius, ...}
:param clusters: Dictionary of {cluster_seed: member_list}; Example {(x, y): [(x, y), (x, y), ...], ...}
:return: Cluster seeds dictionary with updates positions and radius values
"""
radius = sys.maxsize
new_seeds = dict()
if clusters is None: # If we only provided seeds, initialize their radius
for seed in seeds:
for other_seed in seeds.copy():
if other_seed == seed:
continue
dist = math.dist(seed, other_seed)
# Track the smallest distance between 2 centroids
radius = dist if dist < radius else radius
# Update all seeds to the initial cluster radius
radius /= 2
for seed in seeds:
seeds[seed] = radius
else:
# Update centroid positions for clusters if they were provided
for centroid, members in clusters.items():
cluster_data = set(members) | {centroid}
avgX, avgY = points_average(cluster_data)
new_seeds[tuple((avgX, avgY))] = seeds[centroid]
# If we have passed the CLI flag to lock cluster radius, return new seeds without updating radius
# + If we have not passed the -l flag, update cluster radius
seeds = new_seeds if context.lock_radius else update_clusters(new_seeds)
return seeds
def cluster_data(data, seeds):
"""
Runs K-Means clustering on some provided data using a dictionary of cluster seeds {centroid: radius}
:param data: A list of data points to cluster [(x, y), (x, y), ...]
:param seeds: Dictionary of cluster centroid positions and radius {centroid: radius}
:return: Dictionary of final clusters found {centroid: member_list, ...} and updated seeds dictionary
"""
outliers = set()
clusters = {}
for seed in seeds: # Initialize empty clusters for each seed
# If centroid is a data point, it is also a member of the cluster
clusters[seed] = [seed] if seed in data else []
print(f'Updating cluster membership using cluster seeds, radius: ')
for seed, radius in seeds.items():
print(f'\t(({seed[0]:.4f}, {seed[1]:.4f}), {radius:.4f})')
# For each point, calculate the distance from all seeds
for point in data:
for seed, radius in seeds.items():
if point is seed: # Do not check for distance(point, point)
continue
dist = math.dist(point, seed)
if dist <= radius: # If the distance from any cluster is within range, add point to the cluster
# This print statement is noisy, but it can be uncommented to see output for each new cluster member
# print(f'{point} added to cluster {seed}\n\tDistance ({dist}) is within radius ({radius})')
# Take union of point and cluster data
clusters.update({seed: list(set(clusters[seed]) | set([point]))})
# Initialize outliers using difference between sets
outliers = set(data) - (set(chain(*clusters.values())) | set(clusters.keys()))
print(f'Outliers present: {outliers}')
return clusters, seeds
def show_clusters(data, seeds, plot, show=True):
"""
Shows clusters using matplotlib
:param data: Data points to draw on the scatter plot
:param seeds: Cluster seed dictionary {centroid: radius, ...}
:param plot: The subplot to plot data on
:param show: Toggles displaying a window for the plot.
Allows two plots to be drawn on the same subplot and then shown together using a subsequent call to plt.show()
"""
dataX, dataY = zip(*data)
plot.set_aspect(1. / plot.get_data_ratio())
plot.scatter(dataX, dataY, c='k')
# Draw circles for clusters
cs = []
while len(cs) < context.clusters: # Ensure we have enough colors to display all clusters
cs.extend(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
for seed, radius, c in zip(seeds.keys(), seeds.values(), cs):
plot.scatter(seed[0], seed[1], color=c)
circle = plt.Circle(seed, radius, alpha=0.25, color=c)
plot.add_patch(circle)
plot.grid()
if show:
print(f'Close window to update centroid positions and re-cluster data...')
plt.show()
def print_cluster_info(initial_clusters, seeds, centroid_diff):
"""
Outputs some information on clusters after each iteration
:param initial_clusters: The clusters as they were before reclustering
:param seeds: The new seeds dictionary {centroid: radius, ...}
:param centroid_diff: List of difference in centroid positions for each cluster
"""
for initial_point, initial_radius, updated, radius, dist in\
zip(initial_clusters.keys(), initial_clusters.values(), seeds.keys(), seeds.values(), centroid_diff):
print(f'Initial cluster at ({initial_point[0]:.4f}, {initial_point[1]:.4f}) '
f'moved to ({updated[0]:.4f}, {updated[1]:.4f})'
f'\n\tTotal shift: {dist:.4f}'
f'\n\tFinal radius: {radius:.4f}')
if initial_radius != radius:
print(f'\tInitial radius: {initial_radius:.4f}')
################################################################################
# Main
################################################################################
# ==============================================================================
def main(args: List[str]):
parser = init_parser()
global context
context = parser.parse_args(args[1:])
if context.file: # If a file was provided, use that data instead
context.data = [literal_eval(line.rstrip()) for line in context.file]
context.data = [(float(x), float(y)) for x, y in context.data]
elif context.random: # If random flag was set, randomly generate some data
print("TODO: Randomly generate data")
context.data = random_data()
print(
f'Finding K-means clusters for given data {context.data}\n'
f'\tUsing {context.clusters} clusters, {context.shift} max centroid shift, and {context.loops} iterations'
)
seeds = {}
if context.seeds: # Enforce CLUSTER_COUNT matching initial number of seeds
context.clusters = len(context.seeds)
seeds = update_clusters(dict.fromkeys(context.seeds, 0))
else: # Select 2 random seeds once, before we enter clustering loop
seeds = select_seeds(context.data)
# Save a copy of the initial clusters to show comparison at the end
initial_clusters = seeds.copy()
for loop in range(0, context.loops):
print(f'\nClustering iteration {loop}')
plt.title(f'Cluster iteration {loop}')
# Check distance from all points to seed
clusters, seeds = cluster_data(context.data, seeds)
if loop > 0: # The initial graph has no centroid shift to print
# If we are on any iteration beyond the first, print updated cluster information
# + The first iteration shows initial data, since it has no updated data yet
print_cluster_info(prev_centroids, seeds, centroid_diff)
if context.verbose:
print(f'Cluster members:')
for member in [f'{np.round(cent, 4)}: {members}' for cent, members in clusters.items()]:
print(member)
elif loop == 0 and not context.silent:
# If we are on the first iteration, show the initial data provided through CLI
print(
f'Showing initial data with {context.clusters} clusters '
f'given seed points {round_points(seeds.keys())}'
)
# Show the plot for every iteration if it is not suppressed by the CLI --silent flag
if not context.silent:
show_clusters(context.data, seeds, plt.subplot())
# Update centroids for new cluster data
prev_centroids = seeds.copy()
seeds = update_clusters(seeds, clusters)
print(
f'\nUpdated clusters ({round_points(prev_centroids.keys())}) '
f'with new centroids {round_points(seeds.keys())}'
)
# Find the difference in position for all centroids using their previous and current positions
centroid_diff = [round(math.dist(prev, curr), 4) for prev, curr in
list(zip(prev_centroids.keys(), seeds.keys()))]
print(f'New centroids {round_points(seeds.keys())} shifted {centroid_diff} respectively')
# If any centroid has moved more than context.shift, the clusters are not stable
stable = not any((diff > context.shift for diff in centroid_diff))
if stable: # If centroid shift is not > context.shift, centroids have not changed
break # Stop re-clustering process and show final result
print("\n\nShowing final cluster result...")
centroid_diff = [round(math.dist(prev, curr), 4) for prev, curr in
list(zip(initial_clusters.keys(), seeds.keys()))]
print_cluster_info(initial_clusters, seeds, centroid_diff)
# If the clusters reached a point where they were stable, show output to warn
if stable:
print(
f'\nStopping...\n'
f'Cluster centroids have not shifted at least {context.shift}, clusters are stable'
)
if not context.silent:
# Create a side-by-side subplot to compare first iteration with final clustering results
print(f'Close window to exit...')
f, arr = plt.subplots(1, 2)
arr[0].set_title(f'Cluster {0} (Initial result)')
show_clusters(context.data, initial_clusters, arr[0], False)
arr[1].set_title(f'Cluster {loop} (Final result)')
show_clusters(context.data, seeds, arr[1], False)
plt.show()
if __name__ == "__main__":
sys.exit(main(sys.argv))