klips/python/markov-model/markov-model.py

482 lines
18 KiB
Python

################################################################################
# Author: Shaun Reed #
# About: HMM implementation to calculate most probable path for sequence #
# Contact: shaunrd0@gmail.com | URL: www.shaunreed.com | GitHub: shaunrd0 #
################################################################################
from matplotlib import pyplot as plt
from typing import List
import argparse
import itertools
import json
import networkx as nx
import numpy as np
import random
import sys
################################################################################
# CLI Argument Parser
################################################################################
# ==============================================================================
def _open_input_file(path):
    """
    Open the --file argument for reading, reporting failures as argparse usage errors
    :param path: Path of the JSON input file given on the command line
    :return: An open text-mode file object for path
    """
    try:
        return open(path)
    except OSError as err:
        # argparse only turns ArgumentTypeError, TypeError and ValueError into a usage message;
        # a bare OSError from open() would otherwise surface as a traceback
        raise argparse.ArgumentTypeError(f"can't open '{path}': {err}") from err


def init_parser():
    """
    Build the command line parser for the HMM script
    Options: positional observation sequence, --nodes/-n, --edges/-e, --show-all,
    --interactive, --silent and --file/-f
    :return: A configured argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description='Calculates most probable path of HMM given an observation sequence',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        'sequence', metavar='OBSERVATION_SEQUENCE', nargs='*',
        help=(
            "An observation sequence to calculate the most probable path\n"
            "(default: '%(default)s')\n"
        ),
        default=['A', 'B', 'D', 'C'],
    )
    parser.add_argument(
        '--nodes', '-n', metavar='GRAPH_NODE_COUNT', type=int, nargs='?',
        help=(
            "The total number of node states in the HMM graph\n"
            "(default: '%(default)s')\n"
        ),
        default=4,
    )
    parser.add_argument(
        '--edges', '-e', metavar='GRAPH_EDGE_COUNT', type=int, nargs='?',
        help=(
            "The total number of edges in the HMM graph\n"
            "(default: '%(default)s')\n"
        ),
        default=8,
    )
    parser.add_argument(
        '--show-all', action='store_true',
        help=(
            "When this flag is set, all path probabilities and their calculations will be output\n"
            "(default: '%(default)s')\n"
        ),
        default=False,
    )
    parser.add_argument(
        '--interactive', action='store_true',
        help=(
            "Allow taking input to update matrices with triple (row, col, value)\n"
            "(default: '%(default)s')\n"
        ),
        default=False,
    )
    parser.add_argument(
        '--silent', action='store_true',
        help=(
            "When this flag is set, final graph will not be shown\n"
            "(default: '%(default)s')\n"
        ),
        default=False,
    )
    parser.add_argument(
        '--file', '-f', metavar='FILE_PATH', nargs='?', type=_open_input_file,
        # The file is consumed with json.load(); describe that format rather than x,y points
        help=(
            "Optionally provide a JSON file to load the model from. It must define\n"
            "'transition_matrix' and 'emission_matrix'; any other keys override the\n"
            "matching command line options\n"
        ),
    )
    return parser
################################################################################
# Helper Functions
################################################################################
# ==============================================================================
def parse_file():
    """
    Validates keys in JSON file and updates CLI input context
    Initializes a MultiDiGraph object using input data model_graph
    Initializes a matrix of emission probabilities emission_matrix
    The file must define 'transition_matrix' and 'emission_matrix'; every other key
    must be the name of an existing CLI option, whose value it then overrides
    When the file does not set 'nodes' or 'edges' they are derived from the transition
    matrix so the rest of the program sizes and reports the loaded model correctly
    :return: model_graph, emission_matrix
    :raises ValueError: If the file contains a key that is not a matrix or a known CLI option
    """
    # Read the whole document up front so the handle opened by argparse is closed promptly
    with context['file'] as json_file:
        file_data = json.load(json_file)

    # Explicit raise instead of assert: asserts are stripped when Python runs with -O
    matrix_keys = ("transition_matrix", "emission_matrix")
    unknown_keys = [key for key in file_data if key not in matrix_keys and key not in context]
    if unknown_keys:
        raise ValueError(f'Unrecognized keys in input file: {unknown_keys}')

    t_matrix = np.array(file_data['transition_matrix'])
    emission_matrix = np.array(file_data['emission_matrix'])
    edge_list = build_graph(t_matrix)

    # Update the CLI context with JSON input
    context.update(file_data)
    # transition_matrix() allocates context["nodes"] x context["nodes"]; keep it large enough
    # for the loaded model instead of trusting the unrelated CLI default
    if "nodes" not in file_data:
        context["nodes"] = max(t_matrix.shape)
    if "edges" not in file_data:
        context["edges"] = len(edge_list)

    model_graph = nx.MultiDiGraph(edge_list)
    return model_graph, emission_matrix
def random_emission():
    """
    Build a random emission matrix of size SxE
    S is context["nodes"] and E is the number of distinct symbols in context["sequence"]
    Each cell is drawn from uniform(-0.25, 1.0), rounded to 2 places, with negatives set to 0.0
    :return: Initialized emission_matrix
    """
    state_count = context["nodes"]
    symbol_count = len(set(context["sequence"]))
    emission_matrix = np.zeros((state_count, symbol_count))
    # Row-major walk over every cell, one random draw per cell
    for state, symbol in itertools.product(range(state_count), range(symbol_count)):
        # Let random number swing below 0 to increase chance of nodes not emitting
        draw = round(random.uniform(-0.25, 1.0), 2)
        if draw < 0.0:
            draw = 0.0
        emission_matrix[state][symbol] = draw
    return emission_matrix
def random_graph(nodes, edges=2):
    """
    Create a random graph represented as a list [(from_node, to_node, {'weight': edge_weight}), ...]
    Networkx can use this list in constructors for graph objects
    Every (from_node, to_node) pair appears at most once; self-loops are allowed
    :param nodes: The number of nodes in the graph
    :param edges: The number of edges connecting nodes in the graph
        The default value 2 is treated as "twice as many edges as nodes"
    :return: A list [(from_node, to_node, {'weight': edge_weight}), ...]
    :raises ValueError: If more edges are requested than distinct node pairs exist
    """
    # By default, make twice as many edges as there are nodes
    edges *= nodes if edges == 2 else 1
    # Only nodes * nodes distinct directed pairs exist (self-loops included); asking for
    # more would make the rejection sampling below retry forever
    if edges > nodes * nodes:
        raise ValueError(
            f'Cannot place {edges} unique edges in a graph with {nodes} nodes '
            f'(at most {nodes * nodes})'
        )
    r_graph = []
    used_pairs = set()  # O(1) duplicate check instead of rescanning r_graph
    while len(r_graph) < edges:
        from_node = random.randint(0, nodes - 1)  # Randomly select a from_node index
        to_node = random.randint(0, nodes - 1)  # Randomly select a to_node index
        # Randomly set an edge weight between from_node and to_node
        weight = round(random.uniform(0.0, 1.0), 2)
        if (from_node, to_node) in used_pairs:
            continue
        used_pairs.add((from_node, to_node))
        r_graph.append((from_node, to_node, {'weight': weight}))
    return r_graph
def build_graph(t_matrix):
    """
    Converts a transition matrix to a list of edges and weights
    This list can then be passed to NetworkX graph constructors
    Entries that are zero or negative mean "no transition" and produce no edge
    :param t_matrix: The transition matrix to build the graph from (2-D numpy array or nested list)
    :return: A list [(from_node, to_node, {'weight': edge_weight}), ...] in row-major order
    """
    # Accept plain nested lists (e.g. straight from JSON) as well as numpy arrays
    t_matrix = np.asarray(t_matrix)
    row_count, col_count = t_matrix.shape[0], t_matrix.shape[1]
    return [
        (row, col, {'weight': t_matrix[row][col]})
        for row in range(row_count)
        for col in range(col_count)
        if not t_matrix[row][col] <= 0.0
    ]
def transition_matrix(graph: nx.MultiDiGraph):
    """
    Derive the transition matrix of a Networkx MultiDiGraph
    Each edge weight is written to cell [from_node][to_node]; all other cells stay 0
    When context["show_all"] is set, every edge is printed as it is recorded
    :param graph: An initialized MultiDiGraph graph object
    :return: An initialized transition matrix with shape (NODE_COUNT, NODE_COUNT)
    """
    # Square matrix of zeros, one row and one column per state (node)
    node_count = context["nodes"]
    matrix = np.zeros((node_count, node_count))
    # Copy each edge's weight into its matrix cell
    for src, dst, prob in graph.edges(data='weight'):
        matrix[src, dst] = prob
        if context["show_all"]:
            print(f'{src}->{dst}: {prob}')
    return matrix
def make_emission_dict(emission_matrix):
    """
    Create a dictionary that maps to index keys for each emission. emission_keys
    Create a dictionary that maps to a list of emitting nodes for each emission. emission_dict
    Emissions are the distinct symbols of context["sequence"] in sorted order; the i-th
    emission is looked up in column i of emission_matrix
    :param emission_matrix: An emission_matrix size NxE
        Where N is the number of nodes (states) and E is the number of emissions
    :return: emission_dict, emission_keys
    """
    emissions = sorted(set(context["sequence"]))
    state_count = emission_matrix.shape[0]
    emission_dict = {}
    emission_keys = {}
    for column, emission in enumerate(emissions):
        # Record the column for every symbol, even one that no state can emit
        emission_keys[emission] = column
        # A node emits the symbol when its probability in this column is above zero
        emission_dict[emission] = [
            row for row in range(state_count) if emission_matrix[row][column] > 0
        ]
    return emission_dict, emission_keys
def int_input(prompt):
    """
    Prompt until the user types something that parses as an integer
    A warning is printed after every rejected entry and the same prompt is shown again
    :param prompt: The prompt message to show for input
    :return: The integer input by the user at runtime
    """
    while True:
        try:
            return int(input(prompt))
        except ValueError:
            print("Please enter an integer value")
def _float_input(prompt):
    """
    Forces numeric input. Retries and warns if a non-numeric value is entered.
    :param prompt: The prompt message to show for input
    :return: The number input by the user at runtime, as a float
    """
    while True:
        try:
            return float(input(prompt))
        except ValueError:
            print("Please enter a numeric value")


def triple_input(matrix):
    """
    Takes a (row, col, value) triple as input and validates it against the selected matrix
    Row and col are integers; value may be any number, so fractional probabilities
    such as 0.35 can be entered (whole numbers are still accepted)
    If row or column selected is outside the limits of the matrix, warn and retry input until valid
    :param matrix: The matrix to use for input validation
    :return: The validated input as (row, col, value)
    """
    row = int_input("Row: ")
    col = int_input("Col: ")
    # The matrices hold probabilities, so restricting the value to integers would only allow 0 or 1
    value = _float_input("Value: ")
    row, col = check_input(row, col, matrix)
    return row, col, value
def check_input(row, col, matrix):
    """
    Checks that row, col input values are within the bounds of matrix
    If valid values are passed initially, no additional prompts are made.
    Retries input until valid values are input.
    Negative indices are rejected as well: they would otherwise wrap around to the
    end of the matrix or raise IndexError when the caller assigns into it
    :param row: The row index input by the user
    :param col: The col index input by the user
    :param matrix: The matrix to use for input validation
    :return: The validated input for row and column index
    """
    row_count, col_count = matrix.shape[0], matrix.shape[1]
    while not 0 <= row < row_count:
        # Wording is matrix-neutral because this also validates the emission matrix
        print(f'{row} is out of range for matrix of shape {matrix.shape}')
        row = int_input("Row: ")
    while not 0 <= col < col_count:
        print(f'{col} is out of range for matrix of shape {matrix.shape}')
        col = int_input("Col: ")
    return row, col
################################################################################
# Hidden Markov Model
################################################################################
# ==============================================================================
def find_paths(emission_dict, t_matrix):
    """
    Enumerate every node path that could produce the observation sequence in context["sequence"]
    :param emission_dict: A dictionary of emitters for emissions {emission_1: [0, 1], emission_2: [1, 3], ...}
    :param t_matrix: A transition matrix size NxN where N is the total number of nodes in the graph
    :return: A list of validated paths for the emission given through the CLI
    """
    # One list of candidate emitting nodes per observed symbol, in sequence order
    emitters_per_step = [emission_dict[symbol] for symbol in context["sequence"]]
    # The cartesian product of those lists is every conceivable path
    candidates = list(itertools.product(*emitters_per_step))
    # Keep only the paths whose every transition has probability > 0
    return validate_paths(candidates, t_matrix)
def validate_paths(path_list: list, t_matrix):
    """
    Checks all paths in path_list [[0, 1, 2, 3], [0, 1, 1, 2], ...]
    If the transition matrix t_matrix indicates any node in a path can't reach the next node in path
    The path can't happen given our graph. Remove it from the list of paths.
    Paths with fewer than two nodes contain no transitions and are always kept.
    :param path_list: A list of paths to validate
    :param t_matrix: A transition matrix size NxN where N is the total number of nodes in the graph
    :return: A list of validated paths [[0, 1, 2, 3], [0, 1, 1, 2], ...]
    """
    def can_occur(path):
        # all() stops at the first transition whose probability is 0 (or negative)
        return all(
            not t_matrix[path[step]][path[step + 1]] <= 0.0
            for step in range(len(path) - 1)
        )

    return [path for path in path_list if can_occur(path)]
def find_probability(emission_matrix, t_matrix, emission_keys, valid_paths):
    """
    Find probability of paths occurring given our current HMM
    Store result in a dictionary {probability: (0, 1, 2, 3), probability_2: (0, 0, 1, 2)}
    A path's probability is the product of (emission * transition) for every step but the
    last, times the emission probability of the last node
    Paths of a single node (one observation) are scored by that emission alone
    :param emission_matrix: A matrix of emission probabilities NxE where N is the emitting node and E is the emission
    :param t_matrix: A transition matrix NxN where N is the total number of nodes in the graph
    :param emission_keys: A dictionary mapping to index values for emissions as E in the emission_matrix
    :param valid_paths: A list of valid paths to calculate probability given an emission sequence
    :return: A dictionary of {prob: path}; For example {probability: (0, 1, 2, 3), probability_2: (0, 0, 1, 2)}
    """
    path_prob = {}
    seq = list(context["sequence"])
    for path in valid_paths:
        if not path:
            # An empty observation sequence yields one empty path; there is nothing to score
            continue
        calculations = f'Calculating {path}: '
        prob = 1.0
        last_step = len(path) - 1
        for step in range(last_step):
            current_node = path[step]
            emission_prob = emission_matrix[current_node][emission_keys[seq[step]]]
            transition_prob = t_matrix[current_node][path[step + 1]]
            calculations += f'({emission_prob:.2f} * {transition_prob:.2f}) * '
            prob *= emission_prob * transition_prob
        # Index the last node directly rather than through loop variables, which are
        # never bound when the path has only one node
        final_emission_prob = emission_matrix[path[last_step]][emission_keys[seq[last_step]]]
        prob *= final_emission_prob
        calculations += f'{final_emission_prob:.2f} = {prob:.6f}'
        if prob > 0.0:  # Don't keep paths which aren't possible due to emission sequence
            # NOTE: keyed by probability, so paths with exactly equal probability overwrite each other
            path_prob[prob] = path
        if context["show_all"]:
            print(calculations)
    return path_prob
def run_problem(transition_matrix, emission_matrix):
    """
    Solve the HMM for context["sequence"] and print a summary of the model that was used
    :param transition_matrix: A matrix size NxN where N is the total number of nodes and values represent probability
    :param emission_matrix: A matrix size NxE where N is total nodes and E is total number of emissions
    :return: A dictionary of {probability: path} ordered by probability, highest first
    """
    # emission_dict is {emission: [emitter, ...]}, emission_keys is {emission: column index}
    emission_dict, emission_keys = make_emission_dict(emission_matrix)
    candidate_paths = find_paths(emission_dict, transition_matrix)
    path_prob = find_probability(emission_matrix, transition_matrix, emission_keys, candidate_paths)

    # Dicts keep insertion order, so inserting from the largest probability down ranks the result
    ranked = {}
    for prob in sorted(path_prob, reverse=True):
        ranked[prob] = path_prob[prob]

    summary = (
        f'Finding most probable path for given observation sequence: {context["sequence"]}\n'
        f'\tTotal nodes in graph: {context["nodes"]}\n'
        f'\tTotal edges in graph: {context["edges"]}\n'
        f'\tNumber of sequences: {len(set(context["sequence"]))}\n'
        f'\tInteractive mode: {context["interactive"]}\n'
        f'\tEmitting nodes: {emission_dict}\n'
        f'Transition matrix: \n{transition_matrix}\n'
        f'Emission matrix: \n{emission_matrix}'
    )
    print(summary)
    return ranked
def show_result(result):
    """
    Prints results from running the HMM calculations
    With context["show_all"] every path is listed; otherwise only the first (most probable) one
    :param result: The result dictionary returned by run_problem(), ordered by descending probability
    """
    if not result:
        print(f'No valid paths found for sequence {context["sequence"]}')
    elif context["show_all"]:
        print('Final paths sorted by probability:')
        # Plain loop: the prints are side effects, not values worth collecting in a list
        for prob, path in result.items():
            print(f'{path} has probability:\t {prob:.6f}')
    else:
        # The first entry is the most probable path; take it without copying the whole dict
        best_prob, best_path = next(iter(result.items()))
        print(f'{best_path} has the highest probability of {best_prob}')
def draw_graph(graph):
    """
    Render the HMM graph with NetworkX and display it in a matplotlib window
    Nodes are sized in proportion to their degree
    :param graph: An initialized MultiDiGraph object with edge weights representing transition probability
    """
    # {node: position} layout shared by the node drawing and the edge labels
    layout = nx.spring_layout(graph)
    node_sizes = [degree * 200 for _node, degree in graph.degree]
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_size=node_sizes,
        alpha=1,
        arrowstyle="->",
        arrowsize=25,
    )
    # TODO: Fix parallel path weight display
    nx.draw_networkx_edge_labels(graph, layout)
    plt.show()
################################################################################
# Main
################################################################################
# ==============================================================================
def main(args: List[str]):
    """
    Program entry point
    Parses the CLI into the global context, loads or randomly generates an HMM, solves and
    displays it, then in interactive mode applies one matrix edit and solves it again
    :param args: The full argument vector, program name first (e.g. sys.argv)
    """
    global context
    context = vars(init_parser().parse_args(args[1:]))

    if not context["file"]:
        # No file was provided: build a random graph with the requested number of nodes and edges
        model_graph = nx.MultiDiGraph(random_graph(context["nodes"], context["edges"]))
        # ...and a random emission matrix to go with it
        emission_matrix = random_emission()
    else:
        # A file was provided, use that data instead
        model_graph, emission_matrix = parse_file()

    t_matrix = transition_matrix(model_graph)
    show_result(run_problem(t_matrix, emission_matrix))

    # Draw the graph for a visual example to go with output
    if not context["silent"]:
        draw_graph(model_graph)

    # Outside interactive mode there is nothing left to do
    if not context["interactive"]:
        return

    # Prompt to update the transition or emission matrix, then rerun problem with new values
    print("Choose matrix to update:\n\t1. Transition\n\t2. Emission\n\t3. Both", end='')
    choice = input()
    update_both = choice == '3'

    if choice == '1' or update_both:
        if update_both:
            print('\nInput for updating transition matrix')
        row, col, value = triple_input(t_matrix)
        t_matrix[row][col] = value

    if choice == '2' or update_both:
        if update_both:
            print('\nInput for updating emission matrix')
        row, col, value = triple_input(emission_matrix)
        emission_matrix[row][col] = value

    # Any other choice falls through and the problem is rerun unchanged
    show_result(run_problem(t_matrix, emission_matrix))

    # Rebuild the graph from the edited transition matrix before drawing it again
    if not context["silent"]:
        model_graph = nx.MultiDiGraph(build_graph(np.array(t_matrix)))
        draw_graph(model_graph)
if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value becomes the exit status
    exit_status = main(sys.argv)
    sys.exit(exit_status)