################################################################################ # Author: Shaun Reed # # About: HMM implementation to calculate most probable path for sequence # # Contact: shaunrd0@gmail.com | URL: www.shaunreed.com | GitHub: shaunrd0 # ################################################################################ from matplotlib import pyplot as plt from typing import List import argparse import itertools import json import networkx as nx import numpy as np import random import sys ################################################################################ # CLI Argument Parser ################################################################################ # ============================================================================== def init_parser(): parser = argparse.ArgumentParser( description='Calculates most probable path of HMM given an observation sequence', formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument( 'sequence', metavar='OBSERVATION_SEQUENCE', nargs='*', help= '''An observation sequence to calculate the most probable path (default: '%(default)s') ''', default=['A', 'B', 'D', 'C'] ) parser.add_argument( '--nodes', '-n', metavar='GRAPH_NODE_COUNT',type=int, nargs='?', help= '''The total number of node states in the HMM graph (default: '%(default)s') ''', default=4 ) parser.add_argument( '--edges', '-e', metavar='GRAPH_EDGE_COUNT',type=int, nargs='?', help= '''The total number of edges in the HMM graph (default: '%(default)s') ''', default=8 ) parser.add_argument( '--show-all', action='store_true', help= '''When this flag is set, all path probabilities and their calculations will be output (default: '%(default)s') ''', default=False ) parser.add_argument( '--interactive', action='store_true', help= '''Allow taking input to update matrices with triple (row, col, value) (default: '%(default)s') ''', default=False ) parser.add_argument( '--silent', action='store_true', help= '''When this flag is set, final graph will not be shown (default: '%(default)s') ''', default=False ) parser.add_argument( '--file', '-f', metavar='FILE_PATH', nargs='?', type=open, help= '''Optionally provide file for data to be read from. Each point must be on it\'s own line with format x,y ''', ) return parser ################################################################################ # Helper Functions ################################################################################ # ============================================================================== def parse_file(): """ Validates keys in JSON file and updates CLI input context Initializes a MultiDiGraph object using input data model_graph Initializes a matrix of emission probabilities emission_matrix :return: model_graph, emission_matrix """ # Load the JSON input file, validate keys file_data = json.load(context['file']) for key in file_data: if key == "transition_matrix" or key == "emission_matrix": continue assert key in context # Update the CLI context with JSON input context.update(file_data) model_graph = nx.MultiDiGraph(build_graph(np.array(file_data['transition_matrix']))) emission_matrix = np.array(file_data['emission_matrix']) return model_graph, emission_matrix def random_emission(): """ Initialize an emission matrix size SxE Where S is number of states and E is number of emissions :return: Initialized emission_matrix """ emission_matrix = np.zeros((context["nodes"], len(set(context["sequence"])))) shape = emission_matrix.shape for row in range(0, shape[0]): for col in range(0, shape[1]): # Let random number swing below 0 to increase chance of nodes not emitting emit_prob = round(random.uniform(-0.25, 1.0), 2) emit_prob = 0.0 if emit_prob < 0.0 else emit_prob emission_matrix[row][col] = emit_prob return emission_matrix def random_graph(nodes, edges=2): """ Create a random graph represented as a list [(from_node, to_node, {'weight': edge_weight}), ...] Networkx can use this list in constructors for graph objects :param nodes: The number of nodes in the graph :param edges: The number of edges connecting nodes in the graph :return: A list [(from_node, to_node, {'weight': edge_weight}), ...] """ # By default, make twice as many edges as there are nodes edges *= nodes if edges == 2 else 1 r_graph = [] for x in range(0, edges): while True: new_edge = ( random.randint(0, nodes - 1), # Randomly select a from_node index random.randint(0, nodes - 1), # Randomly select a to_node index { # Randomly set an edge weight between from_node and to_node 'weight': round(random.uniform(0.0, 1.0), 2) } ) if not any((new_edge[0], new_edge[1]) == (a, b) for a, b, w in r_graph): break r_graph.append(new_edge) return r_graph def build_graph(t_matrix): """ Converts a transition matrix to a list of edges and weights This list can then be passed to NetworkX graph constructors :param t_matrix: The transition matrix to build the graph from :return: A list [(from_node, to_node, {'weight': edge_weight}), ...] """ n_graph = [] shape = t_matrix.shape for row in range(0, shape[0]): for col in range(0, shape[1]): if t_matrix[row][col] <= 0.0: continue new_edge = (row, col, {'weight': t_matrix[row][col]}) n_graph.append(new_edge) return n_graph def transition_matrix(graph: nx.MultiDiGraph): """ Build a transition matrix from a Networkx MultiDiGraph object :param graph: An initialized MultiDiGraph graph object :return: An initialized transition matrix with shape (NODE_COUNT, NODE_COUNT) """ # Initialize a matrix of zeros with size ExE where E is total number of states (nodes) t_matrix = np.zeros((context["nodes"], context["nodes"])) # Build matrices from iterating over the graph for a, b, weight in graph.edges(data='weight'): t_matrix[a][b] = weight if context["show_all"]: print(f'{a}->{b}: {weight}') return t_matrix def make_emission_dict(emission_matrix): """ Create a dictionary that maps to index keys for each emission. emission_keys Create a dictionary that maps to a list of emitting nodes for each emission. emission_dict :param emission_matrix: An emission_matrix size NxE Where N is the number of nodes (states) and E is the number of emissions :return: emission_dict, emission_keys """ emission_dict = {} for emission in sorted(set(context["sequence"])): emission_dict[emission] = [] emission_keys = dict.fromkeys(emission_dict.keys()) # Initialize emission_dict to store a list of all nodes that emit the key value shape = emission_matrix.shape i = 0 for key in emission_dict.keys(): for row in range(0, shape[0]): if emission_matrix[row][i] > 0: emission_dict[key].append(row) emission_keys[key] = i i += 1 return emission_dict, emission_keys def int_input(prompt): """ Forces integer input. Retries and warns if bogus values are entered. :param prompt: The initial prompt message to show for input :return: The integer input by the user at runtime """ while True: try: value = int(input(prompt)) break except ValueError: print("Please enter an integer value") return value def triple_input(matrix): """ Takes 3 integer input, validates it makes sense for the selected matrix If row or column selected is outside the limits of the matrix, warn and retry input until valid :param matrix: The matrix to use for input validation :return: The validated input """ row = int_input("Row: ") col = int_input("Col: ") value = int_input("Value: ") row, col = check_input(row, col, matrix) return row, col, value def check_input(row, col, matrix): """ Checks that row, col input values are within the bounds of matrix If valid values are passed initially, no additional prompts are made. Retries input until valid values are input. :param row: The row index input by the user :param col: The col index input by the user :param matrix: The matrix to use for input validation :return: The validated input for row and column index """ while row > matrix.shape[0] - 1: print(f'{row} is too large for transition matrix of shape {matrix.shape}') row = int_input("Row : ") while col > matrix.shape[1] - 1: print(f'{col} is too large for transition matrix of shape {matrix.shape}') col = int_input("Col: ") return row, col ################################################################################ # Hidden Markov Model ################################################################################ # ============================================================================== def find_paths(emission_dict, t_matrix): """ Find all possible paths for an emission sequence :param emission_dict: A dictionary of emitters for emissions {emission_1: [0, 1], emission_2: [1, 3], ...} :param t_matrix: A transition matrix size NxN where N is the total number of nodes in the graph :return: A list of validated paths for the emission given through the CLI """ paths = [] for emission in context["sequence"]: paths.append(emission_dict[emission]) # Take the cartesian product of the emitting nodes to get a list of all possible paths # + Return only the paths which have > 0 probability given the transition matrix return validate_paths(list(itertools.product(*paths)), t_matrix) def validate_paths(path_list: list, t_matrix): """ Checks all paths in path_list [[0, 1, 2, 3], [0, 1, 1, 2], ...] If the transition matrix t_matrix indicates any node in a path can't reach the next node in path The path can't happen given our graph. Remove it from the list of paths. :param path_list: A list of paths to validate :param t_matrix: A transition matrix size NxN where N is the total number of nodes in the graph :return: A list of validated paths [[0, 1, 2, 3], [0, 1, 1, 2], ...] """ valid_paths = [] for path in path_list: valid = True for step in range(0, len(path) - 1): current_node = path[step] # If the transition matrix indicates that the chance to move to next step in path is 0 if t_matrix[current_node][path[step+1]] <= 0.0: # The path cannot possibly happen. Don't consider it. valid = False break if valid: # We reached the end of our path without hitting a dead-end. The path is valid. valid_paths.append(path) return valid_paths def find_probability(emission_matrix, t_matrix, emission_keys, valid_paths): """ Find probability of paths occurring given our current HMM Store result in a dictionary {probability: (0, 1, 2, 3), probability_2: (0, 0, 1, 2)} :param emission_matrix: A matrix of emission probabilities NxE where N is the emitting node and E is the emission :param t_matrix: A transition matrix NxN where N is the total number of nodes in the graph :param emission_keys: A dictionary mapping to index values for emissions as E in the emission_matrix :param valid_paths: A list of valid paths to calculate probability given an emission sequence :return: A dictionary of {prob: path}; For example {probability: (0, 1, 2, 3), probability_2: (0, 0, 1, 2)} """ path_prob = {} seq = list(context["sequence"]) for path in valid_paths: calculations = f'Calculating {path}: ' prob = 1.0 for step in range(0, len(path) - 1): current_node = path[step] next_node = path[step + 1] emission_index = emission_keys[seq[step]] emission_prob = emission_matrix[current_node][emission_index] transition_prob = t_matrix[current_node][next_node] calculations += f'({emission_prob:.2f} * {transition_prob:.2f}) * ' prob *= emission_prob * transition_prob emission_index = emission_keys[seq[step + 1]] final_emission_prob = emission_matrix[next_node][emission_index] prob *= final_emission_prob calculations += f'{final_emission_prob:.2f} = {prob:.6f}' if prob > 0.0: # Don't keep paths which aren't possible due to emission sequence path_prob[prob] = path if context["show_all"]: print(calculations) return path_prob def run_problem(transition_matrix, emission_matrix): """ Runs the HMM calculations given a transition_matrix and emission_matrix :param transition_matrix: A matrix size NxN where N is the total number of nodes and values represent probability :param emission_matrix: A matrix size NxE where N is total nodes and E is total number of emissions :return: A dictionary of {probability: path} sorted by probability key from in descending order """ # Dictionary of {emission: [emitter, ...]} emission_dict, emission_keys = make_emission_dict(emission_matrix) valid_paths = find_paths(emission_dict, transition_matrix) path_prob = find_probability(emission_matrix, transition_matrix, emission_keys, valid_paths) result = {key: path_prob[key] for key in dict.fromkeys(sorted(path_prob.keys(), reverse=True))} print(f'Finding most probable path for given observation sequence: {context["sequence"]}\n' f'\tTotal nodes in graph: {context["nodes"]}\n' f'\tTotal edges in graph: {context["edges"]}\n' f'\tNumber of sequences: {len(set(context["sequence"]))}\n' f'\tInteractive mode: {context["interactive"]}\n' f'\tEmitting nodes: {emission_dict}\n' f'Transition matrix: \n{transition_matrix}\n' f'Emission matrix: \n{emission_matrix}' ) return result def show_result(result): """ Prints results from running the HMM calculations :param result: The result dictionary returned by run_problem() """ if len(result) == 0: print(f'No valid paths found for sequence {context["sequence"]}') elif context["show_all"]: print(f'Final paths sorted by probability:') [print(f'{path} has probability:\t {prob:.6f}') for prob, path in result.items()] else: print(f'{list(result.values())[0]} has the highest probability of {list(result.keys())[0]}') def draw_graph(graph): """ Draws the model_graph for the current HMM using NetworkX :param graph: An initialized MultiDiGraph object with edge weights representing transition probability """ # Get a dictionary of {node: position} for drawing the graph dict_pos = nx.spring_layout(graph) nx.draw( graph, dict_pos, with_labels=True, node_size=[x * 200 for x in dict(graph.degree).values()], alpha=1, arrowstyle="->", arrowsize=25, ) # TODO: Fix parallel path weight display nx.draw_networkx_edge_labels(graph, dict_pos) plt.show() ################################################################################ # Main ################################################################################ # ============================================================================== def main(args: List[str]): parser = init_parser() global context context = vars(parser.parse_args(args[1:])) if context["file"]: # If a file was provided, use that data instead model_graph, emission_matrix = parse_file() else: # If no file was provided, build a random graph with the requested number of nodes and edges model_graph = nx.MultiDiGraph(random_graph(context["nodes"], context["edges"])) # Create a random emission matrix emission_matrix = random_emission() t_matrix = transition_matrix(model_graph) result = run_problem(t_matrix, emission_matrix) show_result(result) # Draw the graph for a visual example to go with output if not context["silent"]: draw_graph(model_graph) # Unless we are in interactive mode, we're finished. Return. if not context["interactive"]: return # Prompt to update the transition or emission matrix, then rerun problem with new values print("Choose matrix to update:\n\t1. Transition\n\t2. Emission\n\t3. Both", end='') choice = input() if choice == '1': row, col, value = triple_input(t_matrix) t_matrix[row][col] = value elif choice == '2': row, col, value = triple_input(emission_matrix) emission_matrix[row][col] = value elif choice == '3': print('\nInput for updating transition matrix') row, col, value = triple_input(t_matrix) t_matrix[row][col] = value print('\nInput for updating emission matrix') row, col, value = triple_input(emission_matrix) emission_matrix[row][col] = value result = run_problem(t_matrix, emission_matrix) show_result(result) # Draw the graph for a visual example to go with output if not context["silent"]: model_graph = nx.MultiDiGraph(build_graph(np.array(t_matrix))) draw_graph(model_graph) if __name__ == "__main__": sys.exit(main(sys.argv))