# This file is part of COAT, and is distributed under the # OSI-approved BSD 3-Clause License. See top-level LICENSE file or # https://github.com/Kitware/COAT/blob/master/LICENSE for details. import random import numpy as np zero_threshold = 0.00000001 class KMNode(object): def __init__(self, id, exception=0, match=None, visit=False): self.id = id self.exception = exception self.match = match self.visit = visit class KuhnMunkres(object): def __init__(self): self.matrix = None self.x_nodes = [] self.y_nodes = [] self.minz = float("inf") self.x_length = 0 self.y_length = 0 self.index_x = 0 self.index_y = 1 def __del__(self): pass def set_matrix(self, x_y_values): xs = set() ys = set() for x, y, value in x_y_values: xs.add(x) ys.add(y) if len(xs) < len(ys): self.index_x = 0 self.index_y = 1 else: self.index_x = 1 self.index_y = 0 xs, ys = ys, xs x_dic = {x: i for i, x in enumerate(xs)} y_dic = {y: j for j, y in enumerate(ys)} self.x_nodes = [KMNode(x) for x in xs] self.y_nodes = [KMNode(y) for y in ys] self.x_length = len(xs) self.y_length = len(ys) self.matrix = np.zeros((self.x_length, self.y_length)) for row in x_y_values: x = row[self.index_x] y = row[self.index_y] value = row[2] x_index = x_dic[x] y_index = y_dic[y] self.matrix[x_index, y_index] = value for i in range(self.x_length): self.x_nodes[i].exception = max(self.matrix[i, :]) def km(self): for i in range(self.x_length): while True: self.minz = float("inf") self.set_false(self.x_nodes) self.set_false(self.y_nodes) if self.dfs(i): break self.change_exception(self.x_nodes, -self.minz) self.change_exception(self.y_nodes, self.minz) def dfs(self, i): x_node = self.x_nodes[i] x_node.visit = True for j in range(self.y_length): y_node = self.y_nodes[j] if not y_node.visit: t = x_node.exception + y_node.exception - self.matrix[i][j] if abs(t) < zero_threshold: y_node.visit = True if y_node.match is None or self.dfs(y_node.match): x_node.match = j y_node.match = i return True else: if t >= zero_threshold: self.minz = min(self.minz, t) return False def set_false(self, nodes): for node in nodes: node.visit = False def change_exception(self, nodes, change): for node in nodes: if node.visit: node.exception += change def get_connect_result(self): ret = [] for i in range(self.x_length): x_node = self.x_nodes[i] j = x_node.match y_node = self.y_nodes[j] x_id = x_node.id y_id = y_node.id value = self.matrix[i][j] if self.index_x == 1 and self.index_y == 0: x_id, y_id = y_id, x_id ret.append((x_id, y_id, value)) return ret def get_max_value_result(self): ret = -100 for i in range(self.x_length): j = self.x_nodes[i].match ret = max(ret, self.matrix[i][j]) return ret def run_kuhn_munkres(x_y_values): process = KuhnMunkres() process.set_matrix(x_y_values) process.km() return process.get_connect_result(), process.get_max_value_result() def test(): values = [] random.seed(0) for i in range(500): for j in range(1000): value = random.random() values.append((i, j, value)) return run_kuhn_munkres(values) if __name__ == "__main__": values = [(1, 1, 3), (1, 3, 4), (2, 1, 2), (2, 2, 1), (2, 3, 3), (3, 2, 4), (3, 3, 5)] print(run_kuhn_munkres(values))