439 lines
17 KiB
Python
439 lines
17 KiB
Python
################################################################################
|
|
# Author: Shaun Reed #
|
|
# About: K-Means clustering CLI #
|
|
# Contact: shaunrd0@gmail.com | URL: www.shaunreed.com | GitHub: shaunrd0 #
|
|
################################################################################
|
|
|
|
from ast import literal_eval
|
|
from itertools import chain
|
|
from matplotlib import pyplot as plt
|
|
from typing import List
|
|
import argparse
|
|
import math
|
|
import numpy as np
|
|
import random
|
|
import sys
|
|
|
|
|
|
################################################################################
|
|
# CLI Argument Parser
|
|
################################################################################
|
|
|
|
# ==============================================================================
|
|
|
|
def init_parser():
    """
    Builds the argparse parser describing the K-means clustering CLI

    Positional arguments (each optional): CLUSTER_COUNT (default 2), CENTROID_SHIFT (default 1.0), LOOP_COUNT (default 3)
    Options: --data, --seeds, --silent, --verbose, --random, --radius, --lock-radius, --file

    :return: A configured argparse.ArgumentParser
    """
    # RawTextHelpFormatter keeps the explicit newlines in the help strings below exactly as written
    parser = argparse.ArgumentParser(
        description='K-means clustering program for clustering data read from a file, terminal, or randomly generated',
        formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument(
        'clusters', metavar='CLUSTER_COUNT', type=int, nargs='?',
        help=(
            "Total number of desired clusters\n"
            "(default: '%(default)s')\n"
        ),
        default=2
    )

    parser.add_argument(
        'shift', metavar='CENTROID_SHIFT', type=float, nargs='?',
        help=(
            "Centroid shift threshold. If cluster centroids move less-than this value, clustering is finished\n"
            "(default: '%(default)s')\n"
        ),
        default=1.0
    )

    parser.add_argument(
        'loops', metavar='LOOP_COUNT', type=int, nargs='?',
        help=(
            "Maximum count of loops to perform clustering\n"
            "(default: '%(default)s')\n"
        ),
        default=3
    )

    # Each X,Y token is converted to a (float, float) tuple by the point() helper
    parser.add_argument(
        '--data', '-d', metavar='X,Y', type=point, nargs='*',
        help=(
            "A list of data points separated by spaces as: x,y x,y x,y ...\n"
            "(default: '%(default)s')\n"
        ),
        default=[(1.0, 2.0), (2.0, 3.0), (2.0, 2.0), (5.0, 6.0), (6.0, 7.0), (6.0, 8.0), (7.0, 11.0), (1.0, 1.0)]
    )

    parser.add_argument(
        '--seeds', '--seed', '-s', metavar='X,Y', type=point, nargs='*',
        help=(
            "A list of seed points separated by spaces as: x,y x,y x,y ...\n"
            "Number of seeds provided must match CLUSTER_COUNT, or else CLUSTER_COUNT will be overridden.\n"
        ),
    )

    parser.add_argument(
        '--silent', action='store_true',
        help=(
            "When this flag is set, scatter plot visualizations will not be shown\n"
            "(default: '%(default)s')\n"
        ),
        default=False
    )

    parser.add_argument(
        '--verbose', '-v', action='store_true',
        help=(
            "When this flag is set, cluster members will be shown in output\n"
            "(default: '%(default)s')\n"
        ),
        default=False
    )

    parser.add_argument(
        '--random', '-r', action='store_true',
        help=(
            "When this flag is set, data will be randomly generated\n"
            "(default: '%(default)s')\n"
        ),
        default=False
    )

    parser.add_argument(
        '--radius', metavar='RADIUS', type=float, nargs='?',
        help=(
            "Initial radius to use for clusters\n"
            "(default: '%(default)s')\n"
        ),
        default=None
    )

    parser.add_argument(
        '--lock-radius', '-l', action='store_true',
        help=(
            "When this flag is set, centroid radius will not be recalculated\n"
            "(default: '%(default)s')\n"
        ),
        default=False
    )

    # type=open makes argparse hand back an already-opened text file object for FILE_PATH
    parser.add_argument(
        '--file', '-f', metavar='FILE_PATH', nargs='?', type=open,
        help=(
            "Optionally provide file for data to be read from. Each point must be on its own line with format x,y\n"
        ),
    )
    return parser
|
|
|
|
|
|
################################################################################
|
|
# Helper Functions
|
|
################################################################################
|
|
|
|
# ==============================================================================
|
|
|
|
def point(arg):
    """
    Helper function for parsing x,y points provided through argparse CLI

    :param arg: A single argument passed to an option or positional argument
    :return: A tuple (x, y) representing a data point, with both values cast to float
    :raises argparse.ArgumentTypeError: If arg is not exactly two values that can be converted to float
    """
    try:
        x, y = literal_eval(arg)
        return float(x), float(y)  # Cast all point values to float
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError, OverflowError) as err:
        # Only parsing / conversion failures are translated into a CLI error.
        # A bare except here would also swallow KeyboardInterrupt and SystemExit.
        raise argparse.ArgumentTypeError("Please provide data points in x,y format") from err
|
|
|
|
|
|
def random_data():
    """
    Builds a random set of integer data points for trying out clustering

    The number of points is drawn between 50 and a randomly chosen upper bound (100 to 200)
    Each coordinate is drawn from 0 to 100 inclusive

    :return: A list of random data point tuples [(1, 1), (2, 4), ...]
    """
    upper_bound = random.randint(100, 200)
    total_points = random.randint(50, upper_bound)
    # x is drawn before y for every point, one point at a time
    return [(random.randint(0, 100), random.randint(0, 100)) for _ in range(total_points)]
|
|
|
|
|
|
def round_points(points, precision=4):
    """
    Produces a rounded copy of a collection of (x, y) points

    :param points: An iterable of (x, y) data points to round
    :param precision: Number of decimal places to keep for each coordinate
    :return: A new list of (x, y) tuples with both coordinates rounded to the requested precision
    """
    rounded = []
    for px, py in points:
        rounded.append((round(px, precision), round(py, precision)))
    return rounded
|
|
|
|
|
|
################################################################################
|
|
# K-means Clustering
|
|
################################################################################
|
|
|
|
# ==============================================================================
|
|
|
|
def select_seeds(data):
    """
    Randomly select N distinct seeds where N is the number of clusters requested through the CLI

    Reads context.clusters and context.radius from the module-level CLI namespace set by main()

    :param data: A list of data points [(0, 1), (2, 2), (1, 4), ...]
    :return: Dictionary of {seeds: radius}; For example {(2, 2): 5.0, (1, 4): 5.0}
    :raises ValueError: If data does not hold more points than clusters requested,
        or holds fewer distinct points than clusters requested
    """
    if len(data) <= context.clusters:
        raise ValueError(
            f'Need more than {context.clusters} data points to select {context.clusters} seeds, got {len(data)}'
        )

    # Drop duplicates (keeping first-seen order) so every seed is a distinct position.
    # Picking from the raw list and retrying on collisions never terminates when the
    # data holds fewer distinct points than the number of clusters requested.
    unique_points = list(dict.fromkeys(data))
    if len(unique_points) < context.clusters:
        raise ValueError(
            f'Need at least {context.clusters} distinct data points, got {len(unique_points)}'
        )
    chosen = random.sample(unique_points, max(context.clusters, 0))

    if context.radius:
        # An initial radius was provided. Use it for every seed.
        return dict.fromkeys(chosen, context.radius)
    # No initial radius was provided (None or 0), so calculate one from the seed positions
    return update_clusters(dict.fromkeys(chosen, 0))
|
|
|
|
|
|
def points_average(data):
    """
    Computes the mean position of a collection of (x, y) points
    Used for updating cluster centroid positions

    :param data: Sized collection of points [(x, y), (x, y), ...]
    :return: The mean (x, y) position as a tuple of floats
    """
    count = len(data)
    total_x = 0
    total_y = 0
    # Plain running totals; points are added in iteration order
    for pt in data:
        total_x = total_x + pt[0]
        total_y = total_y + pt[1]
    return float(total_x / count), float(total_y / count)
|
|
|
|
|
|
def update_clusters(seeds, clusters=None):
    """
    Initializes cluster radius values, or moves centroids to the average of their members

    Without clusters: every seed radius is set to half of the smallest distance between any two seeds
    The seeds dictionary is updated in place and returned
    With clusters: each centroid moves to the average of its members plus the centroid itself
    Radius values are then recalculated unless context.lock_radius is set

    :param seeds: Dictionary of {cluster_seed: radius}; Example {(x, y): radius, (x, y): radius, ...}
    :param clusters: Dictionary of {cluster_seed: member_list}; Example {(x, y): [(x, y), (x, y), ...], ...}
    :return: Cluster seeds dictionary with updated positions and radius values
    """
    if clusters is not None:
        # Centroid update: build a fresh dictionary keyed by the new centroid positions
        moved = {}
        for centroid, members in clusters.items():
            pool = set(members) | {centroid}
            mean_x, mean_y = points_average(pool)
            moved[(mean_x, mean_y)] = seeds[centroid]
        if context.lock_radius:
            # The -l flag was passed; keep each radius as it was
            return moved
        # Otherwise recalculate radius for the new centroid positions
        return update_clusters(moved)

    # Radius initialization: distances between every ordered pair of different seeds
    pair_distances = (
        math.dist(seed, other)
        for seed in seeds
        for other in seeds
        if other != seed
    )
    # sys.maxsize leads the sequence so it remains the result when there are no pairs to compare
    radius = min(chain((sys.maxsize,), pair_distances)) / 2
    for seed in seeds:
        seeds[seed] = radius
    return seeds
|
|
|
|
|
|
def cluster_data(data, seeds):
    """
    Assigns data points to clusters using a dictionary of cluster seeds {centroid: radius}

    A point joins every cluster whose centroid is within that cluster's radius of the point,
    so a point can be a member of more than one cluster
    Points that are in no cluster and are not a centroid are printed as outliers

    :param data: A list of data points to cluster [(x, y), (x, y), ...]
    :param seeds: Dictionary of cluster centroid positions and radius {centroid: radius}
    :return: Dictionary of clusters found {centroid: member_list, ...} and the seeds dictionary that was passed in
    """
    known_points = set(data)  # Hash the data once so membership checks do not rescan the list

    clusters = {}
    member_sets = {}  # Same members as clusters, kept as sets for quick duplicate checks
    for seed in seeds:  # Initialize empty clusters for each seed
        # If centroid is a data point, it is also a member of the cluster
        clusters[seed] = [seed] if seed in known_points else []
        member_sets[seed] = set(clusters[seed])

    print('Updating cluster membership using cluster seeds, radius: ')
    for seed, radius in seeds.items():
        print(f'\t(({seed[0]:.4f}, {seed[1]:.4f}), {radius:.4f})')

    # For each point, calculate the distance from all seeds
    for pt in data:
        for seed, radius in seeds.items():
            # Compare by value, not identity: equal tuples are not guaranteed to be the same object.
            # A point equal to the seed was already added when the cluster was initialized.
            if pt == seed or pt in member_sets[seed]:
                continue
            if math.dist(pt, seed) <= radius:  # Within range of this cluster, so add the point to it
                clusters[seed].append(pt)
                member_sets[seed].add(pt)

    # Outliers are the data points that are neither a member of any cluster nor a centroid
    outliers = known_points - (set(chain.from_iterable(clusters.values())) | set(clusters))
    print(f'Outliers present: {outliers}')
    return clusters, seeds
|
|
|
|
|
|
def show_clusters(data, seeds, plot, show=True):
    """
    Draws data points and cluster circles using matplotlib

    :param data: Data points to draw on the scatter plot
    :param seeds: Cluster seed dictionary {centroid: radius, ...}
    :param plot: The subplot to plot data on
    :param show: Toggles displaying a window for the plot.
        Passing False lets several subplots be drawn first and shown together by a later call to plt.show()
    """
    xs, ys = zip(*data)
    plot.set_aspect(1. / plot.get_data_ratio())

    plot.scatter(xs, ys, c='k')

    # Repeat the base colors enough times to cover context.clusters clusters
    base_colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    palette = base_colors * -(-context.clusters // len(base_colors))

    # One marker and one translucent circle per cluster
    for (centroid, radius), color in zip(seeds.items(), palette):
        plot.scatter(centroid[0], centroid[1], color=color)
        plot.add_patch(plt.Circle(centroid, radius, alpha=0.25, color=color))

    plot.grid()
    if not show:
        return
    print('Close window to update centroid positions and re-cluster data...')
    plt.show()
|
|
|
|
|
|
def print_cluster_info(initial_clusters, seeds, centroid_diff):
    """
    Prints a before / after summary for each cluster

    Entries of the three arguments are paired up by position

    :param initial_clusters: Dictionary {centroid: radius, ...} as it was before reclustering
    :param seeds: The new seeds dictionary {centroid: radius, ...}
    :param centroid_diff: List of difference in centroid positions for each cluster
    """
    paired = zip(initial_clusters.items(), seeds.items(), centroid_diff)
    for (start, start_radius), (end, end_radius), shift in paired:
        summary = (
            f'Initial cluster at ({start[0]:.4f}, {start[1]:.4f}) '
            f'moved to ({end[0]:.4f}, {end[1]:.4f})'
            f'\n\tTotal shift: {shift:.4f}'
            f'\n\tFinal radius: {end_radius:.4f}'
        )
        print(summary)
        # The starting radius is only worth printing when it changed
        if start_radius != end_radius:
            print(f'\tInitial radius: {start_radius:.4f}')
|
|
|
|
|
|
################################################################################
|
|
# Main
|
|
################################################################################
|
|
|
|
# ==============================================================================
|
|
|
|
def main(args: List[str]):
    """
    Runs the K-means clustering CLI

    Parses arguments into the module-level `context`, loads data (file, random, or CLI / default),
    chooses seeds, then repeats clustering and centroid updates for up to context.loops iterations,
    stopping early once no centroid has moved more than context.shift

    :param args: Full argument list with the program name at index 0, as in sys.argv
    """
    parser = init_parser()
    global context
    context = parser.parse_args(args[1:])
    if context.file:  # If a file was provided, use that data instead
        # argparse already opened the file (type=open); make sure it is closed once it has been read
        with context.file as data_file:
            # Blank lines (for example a trailing newline) hold no point and are skipped
            raw_points = [literal_eval(line.strip()) for line in data_file if line.strip()]
        context.data = [(float(x), float(y)) for x, y in raw_points]
    elif context.random:  # If random flag was set, randomly generate some data
        print("Randomly generating data")
        context.data = random_data()

    print(
        f'Finding K-means clusters for given data {context.data}\n'
        f'\tUsing {context.clusters} clusters, {context.shift} max centroid shift, and {context.loops} iterations'
    )

    if context.seeds:
        # dict.fromkeys drops duplicate seed points, so CLUSTER_COUNT is taken from the distinct seeds
        seeds = dict.fromkeys(context.seeds, context.radius if context.radius else 0)
        context.clusters = len(seeds)  # Enforce CLUSTER_COUNT matching initial number of seeds
        if not context.radius:
            # No initial radius was provided, so calculate one (same rule select_seeds applies)
            seeds = update_clusters(seeds)
    else:  # Select random seeds once, before we enter clustering loop
        seeds = select_seeds(context.data)

    # Save a copy of the initial clusters to show comparison at the end
    initial_clusters = seeds.copy()

    # Defined before the loop so the final report still works when LOOP_COUNT is 0
    stable = False
    loop = 0
    prev_centroids = seeds.copy()
    centroid_diff = []
    for loop in range(0, context.loops):
        print(f'\nClustering iteration {loop}')
        plt.title(f'Cluster iteration {loop}')
        # Check distance from all points to seed
        clusters, seeds = cluster_data(context.data, seeds)
        if loop > 0:  # The initial graph has no centroid shift to print
            # If we are on any iteration beyond the first, print updated cluster information
            # + The first iteration shows initial data, since it has no updated data yet
            print_cluster_info(prev_centroids, seeds, centroid_diff)
            if context.verbose:
                print('Cluster members:')
                for cent, members in clusters.items():
                    print(f'{np.round(cent, 4)}: {members}')
        elif not context.silent:
            # If we are on the first iteration, show the initial data provided through CLI
            print(
                f'Showing initial data with {context.clusters} clusters '
                f'given seed points {round_points(seeds.keys())}'
            )

        # Show the plot for every iteration if it is not suppressed by the CLI --silent flag
        if not context.silent:
            show_clusters(context.data, seeds, plt.subplot())

        # Update centroids for new cluster data
        prev_centroids = seeds.copy()
        seeds = update_clusters(seeds, clusters)
        print(
            f'\nUpdated clusters ({round_points(prev_centroids.keys())}) '
            f'with new centroids {round_points(seeds.keys())}'
        )

        # Find the difference in position for all centroids using their previous and current positions
        centroid_diff = [round(math.dist(prev, curr), 4) for prev, curr in zip(prev_centroids, seeds)]
        print(f'New centroids {round_points(seeds.keys())} shifted {centroid_diff} respectively')

        # If any centroid has moved more than context.shift, the clusters are not stable
        stable = not any(diff > context.shift for diff in centroid_diff)
        if stable:  # If centroid shift is not > context.shift, centroids have not changed
            break  # Stop re-clustering process and show final result

    print("\n\nShowing final cluster result...")
    centroid_diff = [round(math.dist(prev, curr), 4) for prev, curr in zip(initial_clusters, seeds)]
    print_cluster_info(initial_clusters, seeds, centroid_diff)

    # If the clusters reached a point where they were stable, show output to warn
    if stable:
        print(
            f'\nStopping...\n'
            f'Cluster centroids have not shifted at least {context.shift}, clusters are stable'
        )

    if not context.silent:
        # Create a side-by-side subplot to compare first iteration with final clustering results
        print('Close window to exit...')
        _, axes = plt.subplots(1, 2)
        axes[0].set_title(f'Cluster {0} (Initial result)')
        show_clusters(context.data, initial_clusters, axes[0], False)
        axes[1].set_title(f'Cluster {loop} (Final result)')
        show_clusters(context.data, seeds, axes[1], False)
        plt.show()
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script
if __name__ == "__main__":
    exit_status = main(sys.argv)  # Full argv is passed; main() skips the program name itself
    sys.exit(exit_status)
|