151 lines
4.2 KiB
Python
151 lines
4.2 KiB
Python
# This file is part of COAT, and is distributed under the
|
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
|
|
|
import random
|
|
import numpy as np
|
|
|
|
# Tolerance for floating-point slack comparisons: an edge whose slack
# (x label + y label - weight) has magnitude below this counts as "tight".
zero_threshold = 1e-8
|
|
|
|
class KMNode:
    """One vertex of the bipartite graph used by ``KuhnMunkres``.

    Attributes:
        id: caller-supplied identifier of this vertex.
        exception: the vertex label used in the Kuhn-Munkres slack test
            (x label + y label - edge weight).
        match: index of the matched vertex on the opposite side, or None.
        visit: whether the current augmenting-path search has reached it.
    """

    def __init__(self, id, exception=0, match=None, visit=False):
        # `id` shadows the builtin but is part of the public signature.
        (self.id, self.exception,
         self.match, self.visit) = (id, exception, match, visit)
|
|
|
|
|
|
class KuhnMunkres(object):
    """Maximum-weight bipartite matching with the Kuhn-Munkres algorithm.

    Typical use: ``set_matrix`` with ``(x, y, value)`` triples, then ``km``,
    then read results with ``get_connect_result`` / ``get_max_value_result``.

    The graph is treated as complete: any ``(x, y)`` pair missing from the
    input has weight 0. The side with fewer distinct ids becomes the internal
    "x" side, and ``km`` matches every internal x node to a distinct y node.
    """

    def __init__(self):
        self.matrix = None  # (x_length, y_length) weights, built by set_matrix
        self.x_nodes = []  # KMNode per id on the internal (smaller) x side
        self.y_nodes = []  # KMNode per id on the internal y side
        self.minz = float("inf")  # smallest positive slack seen by the last dfs pass
        self.x_length = 0
        self.y_length = 0
        # Tuple positions (0 or 1) of the internal x / y ids inside each input
        # triple; set_matrix swaps them when it swaps the two sides.
        self.index_x = 0
        self.index_y = 1

    def set_matrix(self, x_y_values):
        """Build the nodes, the dense weight matrix and the initial labels.

        Args:
            x_y_values: iterable of ``(x, y, value)`` triples. Each triple must
                unpack into exactly three items and support indexing; ids must
                be hashable. One-shot iterables (e.g. generators) are accepted
                because the input is materialized once up front.

        Pairs that never appear get weight 0. If a pair appears more than once,
        the last value wins. Node order follows set iteration order.
        """
        # The triples are walked twice below; materialize them so a one-shot
        # iterator does not leave the second pass (and thus the matrix) empty.
        rows = list(x_y_values)

        xs = set()
        ys = set()
        for x, y, _ in rows:
            xs.add(x)
            ys.add(y)

        # Every internal x node gets matched, so x must be the smaller side.
        # On a tie the input's second tuple position becomes the internal x side.
        if len(xs) < len(ys):
            self.index_x = 0
            self.index_y = 1
        else:
            self.index_x = 1
            self.index_y = 0
            xs, ys = ys, xs

        x_dic = {x: i for i, x in enumerate(xs)}
        y_dic = {y: j for j, y in enumerate(ys)}
        self.x_nodes = [KMNode(x) for x in xs]
        self.y_nodes = [KMNode(y) for y in ys]
        self.x_length = len(xs)
        self.y_length = len(ys)

        self.matrix = np.zeros((self.x_length, self.y_length))
        for row in rows:
            x_index = x_dic[row[self.index_x]]
            y_index = y_dic[row[self.index_y]]
            self.matrix[x_index, y_index] = row[2]

        # Initial feasible labels: each x starts at its best edge weight and each
        # y at 0 (KMNode's default), so x + y >= weight for every (non-NaN) edge.
        for i in range(self.x_length):
            self.x_nodes[i].exception = max(self.matrix[i, :])

    def km(self):
        """Run the matching so that every internal x node gets a ``match``.

        Raises:
            ValueError: if an augmenting search fails without finding any
                positive slack (e.g. NaN/infinite weights). Shifting the labels
                by an infinite amount could never recover, so this raises
                instead of looping forever.
        """
        for i in range(self.x_length):
            while True:
                self.minz = float("inf")
                self.set_false(self.x_nodes)
                self.set_false(self.y_nodes)

                if self.dfs(i):
                    break

                # No augmenting path over tight edges and no finite slack to
                # relax by: any label update would be +/-inf and the search
                # would fail identically on every further pass.
                if self.minz == float("inf"):
                    raise ValueError(
                        "Kuhn-Munkres cannot make progress; the weight matrix "
                        "likely contains NaN or infinite values"
                    )

                # Lower visited x labels and raise visited y labels by the
                # smallest slack seen, so a new edge can become tight.
                self.change_exception(self.x_nodes, -self.minz)
                self.change_exception(self.y_nodes, self.minz)

    def dfs(self, i):
        """Search for an augmenting path from internal x node ``i``.

        Only tight edges (|slack| < ``zero_threshold``) are followed; recursion
        continues through the x node currently matched to each reached y.
        Every reached node is marked visited. The smallest positive slack among
        examined edges to not-yet-visited y nodes is recorded in ``self.minz``.

        Returns:
            True if a path was found (``match`` fields are updated along it),
            otherwise False.
        """
        x_node = self.x_nodes[i]
        x_node.visit = True
        for j in range(self.y_length):
            y_node = self.y_nodes[j]
            if not y_node.visit:
                # Slack of edge (i, j); ~0 means the edge is tight.
                t = x_node.exception + y_node.exception - self.matrix[i][j]
                if abs(t) < zero_threshold:
                    y_node.visit = True
                    if y_node.match is None or self.dfs(y_node.match):
                        x_node.match = j
                        y_node.match = i
                        return True
                elif t >= zero_threshold:
                    self.minz = min(self.minz, t)
        return False

    def set_false(self, nodes):
        """Clear the ``visit`` flag on every node in ``nodes``."""
        for node in nodes:
            node.visit = False

    def change_exception(self, nodes, change):
        """Add ``change`` to the label of every visited node in ``nodes``."""
        for node in nodes:
            if node.visit:
                node.exception += change

    def get_connect_result(self):
        """Return the matching as ``(x_id, y_id, value)`` triples.

        Ids are reported in the caller's original tuple orientation. ``value``
        comes from the dense matrix, so it is 0 for a pair that was never
        supplied but had to be matched.
        """
        ret = []
        for i, x_node in enumerate(self.x_nodes):
            j = x_node.match
            x_id = x_node.id
            y_id = self.y_nodes[j].id
            value = self.matrix[i][j]
            # set_matrix swapped the sides; swap back into caller orientation.
            if self.index_x == 1 and self.index_y == 0:
                x_id, y_id = y_id, x_id
            ret.append((x_id, y_id, value))
        return ret

    def get_max_value_result(self):
        """Return the largest single edge weight in the matching (not the sum).

        Returns -100 when nothing is matched (empty input). Matched weights
        below -100 are reported as-is rather than clamped to -100.
        """
        matched = [self.matrix[i][node.match] for i, node in enumerate(self.x_nodes)]
        return max(matched, default=-100)
|
|
|
|
|
|
def run_kuhn_munkres(x_y_values):
    """Solve a maximum-weight bipartite matching in a single call.

    Args:
        x_y_values: ``(x, y, value)`` triples, as accepted by
            ``KuhnMunkres.set_matrix``.

    Returns:
        ``(matches, max_value)``: ``matches`` is the list produced by
        ``KuhnMunkres.get_connect_result`` and ``max_value`` is the result of
        ``KuhnMunkres.get_max_value_result``.
    """
    solver = KuhnMunkres()
    solver.set_matrix(x_y_values)
    solver.km()
    matches = solver.get_connect_result()
    max_value = solver.get_max_value_result()
    return matches, max_value
|
|
|
|
|
|
def test():
    """Run the solver on a seeded 500 x 1000 grid of random weights.

    Every (x, y) pair is present, with weights drawn from ``random.random()``
    after ``random.seed(0)``. Returns the result of ``run_kuhn_munkres``.
    """
    random.seed(0)
    num_x, num_y = 500, 1000
    # Row-major draw order (x outer, y inner) keeps the seeded sequence stable.
    triples = [(xi, yj, random.random()) for xi in range(num_x) for yj in range(num_y)]
    return run_kuhn_munkres(triples)
|
|
|
|
|
|
if __name__ == "__main__":
    # Small hand-checkable demo: x ids {1, 2, 3} and y ids {1, 2, 3}, with the
    # pairs (1, 2) and (3, 1) omitted, so the solver treats them as weight 0.
    demo_triples = [
        (1, 1, 3),
        (1, 3, 4),
        (2, 1, 2),
        (2, 2, 1),
        (2, 3, 3),
        (3, 2, 4),
        (3, 3, 5),
    ]
    print(run_kuhn_munkres(demo_triples))
|