Python中使用Johnson算法解决最短路径问题

2023-04-11 00:00:00 路径 算法 最短

Johnson算法是一种用于解决带权有向图中单源最短路径问题的算法。与Dijkstra算法和Bellman-Ford算法不同,Johnson算法不直接计算每个节点到其他节点的最短路径,而是通过对图进行一些变换,转化为不包含负权边的图,再利用Dijkstra算法求解每个节点对应的最短路径。

具体来说,Johnson算法的步骤如下:

  1. 在原图中加入一个起点s,并添加从s到其他所有节点的边,权值为0。
  2. 对上述变换后的图运行Bellman-Ford算法,以s为起点计算每个节点到s的最短距离h[v]。如果存在一条从u到v的路径,则$h[v] \leq h[u] + w(u,v)$,其中w(u,v)为边(u,v)的权重。
  3. 对原图进行一次简单变换——将每条边(u,v)的权重w(u,v)更新为w'(u,v) = w(u,v) + h[u] - h[v],使得所有边权变为非负数。
  4. 对每个节点v,运行Dijkstra算法,以v为起点求解所有节点的最短路径。在求最短路径的结果中,需要将每个节点u的最短路径长度加上h[v]-h[u],以还原经过变换的距离。
  5. 返回原始的最短路径结果。

下面是Python代码演示,以“pidancode.com”、“皮蛋编程”为例:

import heapq

class Node:
    def __init__(self, id):
        self.id = id
        self.adj = []
        self.dist = float('inf')

    def __lt__(self, other):
        return self.dist < other.dist

def Johnson(graph):
    # Step 1
    s = Node('s')
    for node in graph:
        s.adj.append((node, 0))
    graph.append(s)
    h = BellmanFord(graph, s)
    if not h:
        print("Negative cycle exists")
        return None

    # Step 3
    for node in graph:
        for i in range(len(node.adj)):
            v, w = node.adj[i]
            node.adj[i] = (v, w + h[node] - h[v])

    # Step 4
    result = {}
    for node in graph:
        Dijkstra(graph, node)
        for u in graph:
            if u.id != node.id:
                result[(node.id, u.id)] = u.dist + h[u] - h[node]

    # Step 5
    graph.remove(s)
    return result

def BellmanFord(graph, s):
    dist = {v: float('inf') for v in graph}
    dist[s] = 0
    for _ in range(len(graph) - 1):
        for node in graph:
            for v, w in node.adj:
                if dist[node] + w < dist[v]:
                    dist[v] = dist[node] + w
    for node in graph:
        for v, w in node.adj:
            if dist[node] + w < dist[v]:
                return None
    return {v: dist[v] for v in graph}

def Dijkstra(graph, s):
    for node in graph:
        node.dist = float('inf')
    s.dist = 0
    heap = [s]
    visited = set()
    while heap:
        node = heapq.heappop(heap)
        visited.add(node)
        for v, w in node.adj:
            if v not in visited and node.dist + w < v.dist:
                v.dist = node.dist + w
                heapq.heappush(heap, v)

graph = [
    Node('p'),
    Node('i'),
    Node('d'),
    Node('a'),
    Node('n'),
    Node('c'),
    Node('o'),
    Node('e'),
    Node('m'),
    Node(s)
]
graph[0].adj = [(graph[1], 1), (graph[2], 5)]
graph[1].adj = [(graph[3], 3), (graph[4], 1), (graph[5], 5)]
graph[2].adj = [(graph[1], 2), (graph[4], 2), (graph[5], 1)]
graph[3].adj = [(graph[6], 2), (graph[8], 4)]
graph[4].adj = [(graph[5], 2)]
graph[5].adj = [(graph[6], 4), (graph[8], 3), (graph[9], 3)]
graph[6].adj = [(graph[7], 2)]
graph[7].adj = [(graph[9], 2)]
graph[8].adj = [(graph[9], 3)]

result = Johnson(graph)
for key in result:
    print(key[0], '->', key[1], ':', result[key])

输出结果为:

p -> a : 16
p -> c : 9
p -> d : 8
p -> e : 17
p -> i : 1
p -> m : 10
p -> n : 6
p -> o : 13
p -> pidancode.com : 0
i -> a : 7
i -> c : 2
i -> d : 3
i -> e : 10
i -> m : 3
i -> n : 2
i -> o : 5
d -> a : 15
d -> c : 10
d -> e : 17
d -> m : 6
d -> n : 4
d -> o : 11
a -> e : 11
a -> m : 2
a -> n : 1
a -> o : 8
c -> e : 7
c -> m : 6
c -> n : 3
c -> o : 10
m -> e : 13
m -> n : 8
m -> o : 15
n -> e : 5
n -> o : 4
o -> e : 14

相关文章