Dijkstra 算法模板及应用
Info
已完成网站教程、网站习题、配套插件中所有多语言代码的校准,解决了之前 chatGPT 翻译可能出错的问题~
读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:
LeetCode | Difficulty |
---|---|
1514. Path with Maximum Probability | 🟠 |
1631. Path With Minimum Effort | 🟠 |
743. Network Delay Time | 🟠 |
Prerequisites
Before reading this article, you should first learn:
Actually, the underlying principles of many algorithms are exceptionally simple. They just extend step by step, making them look particularly complex and impressive.
However, if you have read historical articles, you should be able to form your own understanding of algorithms. You will find that many algorithms are just old concepts in new packaging, with no real innovation, making them very monotonous.
For example, as mentioned in Dong Ge's Guide to Binary Trees (Summary), binary trees are very important. Once you master this structure, you will realize that Dynamic Programming, Divide and Conquer Algorithms, Backtracking (DFS) Algorithms, BFS Algorithm Framework, Union-Find Algorithm, and Binary Heap Implementing Priority Queue are just different applications of binary trees.
In this article, I will tell you that Dijkstra's Algorithm (often transliterated as Dijkstra Algorithm) is merely an enhanced version of the BFS algorithm, both derived from the level-order traversal of binary trees.
Next, we will discuss Dijkstra's Algorithm in depth, starting from the level-order traversal of binary trees, providing the code framework for Dijkstra's Algorithm, and effortlessly solving some problems using Dijkstra's Algorithm.
图的抽象
前文 图论第一期:遍历基础 说过「图」这种数据结构的基本实现,图中的节点一般就抽象成一个数字(索引),图的具体实现一般是「邻接矩阵」或者「邻接表」。
比如上图这幅图用邻接表和邻接矩阵的存储方式如下:
前文 图论第二期:拓扑排序 告诉你,我们用邻接表的场景更多,结合上图,一幅图可以用如下 Java 代码表示:
// graph[s] stores the nodes that node s points to (out-degree)
List<Integer>[] graph;
// vector<type> name means defining a dynamic array name with element type type
// graph[s] stores the nodes pointed to by node s (out-degree)
vector<int>[] graph;
# graph[s] stores the nodes that node s points to (out-degree)
graph: List[List[int]] = []
// graph[s] stores the nodes that node s points to (out-degree)
var graph [][]int
// graph[s] stores the nodes that node s points to (outdegree)
var graph;
如果你想把一个问题抽象成「图」的问题,那么首先要实现一个 API adj
:
// input node s and return its adjacent nodes
List<Integer> adj(int s);
// input node s and return its adjacent nodes
vector<int> adj(int s);
# input node s and return the adjacent nodes of s
def adj(s: int) -> List[int]:
// input node s, return the adjacent nodes of s
func adj(s int) []int {
}
// input node s and return its adjacent nodes
var adj = function(s) {
var adjList = [];
// implementation
return adjList;
};
类似多叉树节点中的 children
字段记录当前节点的所有子节点,adj(s)
就是计算一个节点 s
的相邻节点。
比如上面说的用邻接表表示「图」的方式,adj
函数就可以这样表示:
List<Integer>[] graph;
// input node s, return the adjacent nodes of s
List<Integer> adj(int s) {
return graph[s];
}
vector<int>* graph;
// input node s, return the adjacent nodes of s
vector<int> adj(int s) {
return graph[s];
}
from typing import List
graph: List[List[int]] = []
# input node s, return the adjacent nodes of s
def adj(s: int) -> List[int]:
return graph[s]
var graph []List[int]
// input node s, return the adjacent nodes of s
func adj(s int) []int {
return graph[s]
}
var graph;
// input node s, return the adjacent nodes of s
function adj(s) {
return graph[s];
}
当然,对于「加权图」,我们需要知道两个节点之间的边权重是多少,所以还可以抽象出一个 weight
方法:
// return the weight of the edge from node 'from' to node 'to'
int weight(int from, int to);
// return the weight of the edge from node 'from' to node 'to'
int weight(int from, int to);
# return the weight of the edge from node 'from' to node 'to'
def weight(from_node: int, to_node: int) -> int:
// weight returns the weight of the edge between node from and node to
func weight(from, to int) int {}
var weight = function (from, to) {
// return the weight of the edge between node from and node to
};
这个 weight
方法可以根据实际情况而定,因为不同的算法题,题目给的「权重」含义可能不一样,我们存储权重的方式也不一样。
有了上述基础知识,就可以搞定 Dijkstra 算法了,下面我给你从二叉树的层序遍历开始推演出 Dijkstra 算法的实现。
二叉树层级遍历和 BFS 算法
我们之前说过二叉树的层级遍历框架:
// Input the root node of a binary tree, level-order traverse this binary tree
void levelTraverse(TreeNode root) {
if (root == null) return 0;
Queue<TreeNode> q = new LinkedList<>();
q.offer(root);
int depth = 1;
// Traverse each level of the binary tree from top to bottom
while (!q.isEmpty()) {
int sz = q.size();
// Traverse each node of each level from left to right
for (int i = 0; i < sz; i++) {
TreeNode cur = q.poll();
printf("节点 %s 在第 %s 层", cur, depth);
// Put the nodes of the next level into the queue
if (cur.left != null) {
q.offer(cur.left);
}
if (cur.right != null) {
q.offer(cur.right);
}
}
depth++;
}
}
// input the root node of a binary tree, perform level-order traversal of this binary tree
int levelTraverse(TreeNode* root) {
if (root == nullptr) return 0;
queue<TreeNode*> q;
q.push(root);
int depth = 1;
// traverse each level of the binary tree from top to bottom
while (!q.empty()) {
int sz = q.size();
// traverse each node of each level from left to right
for (int i = 0; i < sz; i++) {
TreeNode* cur = q.front();
q.pop();
printf("节点 %s 在第 %s 层", cur, depth);
// put the nodes of the next level into the queue
if (cur->left != nullptr) {
q.push(cur->left);
}
if (cur->right != nullptr) {
q.push(cur->right);
}
}
depth++;
}
return depth;
}
# Input the root node of a binary tree, level-order traverse this binary tree
def levelTraverse(root: TreeNode):
if root == None:
return 0
q = []
q.append(root)
depth = 1
# Traverse each level of the binary tree from top to bottom
while len(q) > 0:
sz = len(q)
# Traverse each node of each level from left to right
for i in range(sz):
cur = q.pop(0)
printf("节点 %s 在第 %s 层", cur, depth)
# Put the nodes of the next level into the queue
if cur.left != None:
q.append(cur.left)
if cur.right != None:
q.append(cur.right)
depth += 1
// Input the root node of a binary tree, level order traverse this binary tree
func levelTraverse(root *TreeNode) {
if root == nil {
return 0
}
q := make([]*TreeNode, 0)
q = append(q, root)
depth := 1
// Traverse each level of the binary tree from top to bottom
for len(q) != 0 {
sz := len(q)
// Traverse each node of each level from left to right
for i := 0; i < sz; i++ {
cur := q[i]
fmt.Printf("节点 %v 在第 %v 层\n", cur, depth)
// Put the next level nodes into the queue
if cur.left != nil {
q = append(q, cur.left)
}
if cur.right != nil {
q = append(q, cur.right)
}
}
q = q[sz:]
depth++
}
}
var levelTraverse = function(root) {
if (!root) return 0;
let q = [];
q.push(root);
let depth = 1;
// traverse each level of the binary tree from top to bottom
while (q.length !== 0) {
let sz = q.length;
// traverse each node of each level from left to right
for (let i = 0; i < sz; i++) {
let cur = q.shift();
console.log(`节点 ${cur} 在第 ${depth} 层`);
// push the nodes of the next level into the queue
if (cur.left !== null) {
q.push(cur.left);
}
if (cur.right !== null) {
q.push(cur.right);
}
}
depth++;
}
};
我们先来思考一个问题,注意二叉树的层级遍历 while
循环里面还套了个 for
循环,为什么要这样?
while
循环和 for
循环的配合正是这个遍历框架设计的巧妙之处:
while
循环控制一层一层往下走,for
循环利用 sz
变量控制从左到右遍历每一层二叉树节点。
注意我们代码框架中的 depth
变量,其实就记录了当前遍历到的层数。换句话说,每当我们遍历到一个节点 cur
,都知道这个节点属于第几层。
算法题经常会问二叉树的最大深度呀,最小深度呀,层序遍历结果呀,等等问题,所以记录下来这个深度 depth
是有必要的。
基于二叉树的遍历框架,我们又可以扩展出多叉树的层序遍历框架:
// Input the root node of a multi-way tree, level order traverse this multi-way tree
void levelTraverse(TreeNode root) {
if (root == null) return;
Queue<TreeNode> q = new LinkedList<>();
q.offer(root);
int depth = 1;
// Traverse each level of the multi-way tree from top to bottom
while (!q.isEmpty()) {
int sz = q.size();
// Traverse each node of each level from left to right
for (int i = 0; i < sz; i++) {
TreeNode cur = q.poll();
printf("节点 %s 在第 %s 层", cur, depth);
// Put the next level nodes into the queue
for (TreeNode child : cur.children) {
q.offer(child);
}
}
depth++;
}
}
void levelTraverse(TreeNode* root) {
if (root == nullptr) return;
queue<TreeNode*> q;
q.push(root);
int depth = 1;
while (!q.empty()) {
int sz = q.size();
for (int i = 0; i < sz; i++) {
TreeNode* cur = q.front();
q.pop();
printf("节点 %s 在第 %s 层", cur, depth);
for (auto child : cur->children) {
q.push(child);
}
}
depth++;
}
}
from typing import Optional
from collections import deque
class TreeNode:
def __init__(self, val: Optional[int] = None, children: Optional[List['TreeNode']] = None):
self.val = val
self.children = children or []
# Input the root node of a multi-branch tree and perform level order traversal on this tree
def levelTraverse(root: TreeNode) -> None:
if not root:
return
q = deque([root])
depth = 1
# Traverse each level of the multi-branch tree from top to bottom
while q:
sz = len(q)
# Traverse each node of each level from left to right
for i in range(sz):
cur = q.popleft()
print("节点 {} 在第 {} 层".format(cur.val, depth))
# Put the next level nodes into the queue
for child in cur.children:
q.append(child)
depth += 1
// Input the root node of a multi-way tree and perform level-order traversal on it
func levelTraverse(root *TreeNode) {
if root == nil {
return
}
q := make([]*TreeNode, 0)
q = append(q, root)
depth := 1
// Traverse each level of the multi-way tree from top to bottom
for len(q) > 0 {
sz := len(q)
// Traverse each node of each level from left to right
for i := 0; i < sz; i++ {
cur := q[i]
fmt.Printf("节点 %s 在第 %d 层", cur, depth)
// Enqueue the nodes of the next level
for _, child := range cur.children {
q = append(q, child)
}
}
depth++
q = q[sz:]
}
}
var levelTraverse = function(root) {
if (root == null) return;
var q = [];
q.push(root);
var depth = 1;
while (q.length > 0) {
var sz = q.length;
for (var i = 0; i < sz; i++) {
var cur = q.shift();
console.log("节点 " + cur + " 在第 " + depth + " 层");
for (var j = 0; j < cur.children.length; j++) {
q.push(cur.children[j]);
}
}
depth++;
}
};
基于多叉树的遍历框架,我们又可以扩展出 BFS(广度优先搜索)的算法框架:
// input the starting point to perform BFS search
int BFS(Node start) {
// core data structure
Queue<Node> q;
// avoid walking back
Set<Node> visited;
// add the starting point to the queue
q.offer(start);
visited.add(start);
// record the number of steps in the search
int step = 0;
while (q not empty) {
int sz = q.size();
// spread all nodes in the current queue by one step around
for (int i = 0; i < sz; i++) {
Node cur = q.poll();
printf("从 %s 到 %s 的最短距离是 %s", start, cur, step);
// add adjacent nodes of cur to the queue
for (Node x : cur.adj()) {
if (x not in visited) {
q.offer(x);
visited.add(x);
}
}
}
step++;
}
}
#include <queue>
#include <set>
// Input the starting point and perform BFS search
int BFS(Node* start) {
// Core data structure
std::queue<Node*> q;
// Avoid backtracking
std::set<Node*> visited;
// Add the starting point to the queue
q.push(start);
visited.insert(start);
// Record the number of steps in the search
int step = 0;
while (!q.empty()) {
int sz = q.size();
// Spread all nodes in the current queue one step in all directions
for (int i = 0; i < sz; i++) {
Node* cur = q.front();
q.pop();
printf("从 %s 到 %s 的最短距离是 %s", start->val, cur->val, step);
// Add the adjacent nodes of cur to the queue
for (Node* x : cur->adj()) {
if (visited.count(x) == 0) {
q.push(x);
visited.insert(x);
}
}
}
step++;
}
}
# input starting point, perform BFS search
def BFS(start: Node) -> int:
# core data structure
q = []
# avoid backtracking
visited = set()
# add the starting point to the queue
q.append(start)
visited.add(start)
# record the number of steps in the search
step = 0
while len(q) != 0:
sz = len(q)
# expand all nodes in the current queue by one step
for i in range(sz):
cur = q.pop(0)
print(f"从 {start} 到 {cur} 的最短距离是 {step}")
# add neighboring nodes of cur to the queue
for x in cur.adj():
if x not in visited:
q.append(x)
visited.add(x)
step += 1
func BFS(start Node) int {
// core data structure
q := make([]Node, 0)
// avoid walking back
visited := make(map[Node]bool)
// add the starting point to the queue
q = append(q, start)
visited[start] = true
// record the number of search steps
step := 0
for len(q) != 0 {
sz := len(q)
// spread all nodes in the current queue one step outward
for i := 0; i < sz; i++ {
cur := q[0]
q = q[1:]
fmt.Printf("从 %s 到 %s 的最短距离是 %d", start, cur, step)
// add the adjacent nodes of cur to the queue
adj := cur.adj()
for _, x := range adj {
if _, ok := visited[x]; !ok {
q = append(q, x)
visited[x] = true
}
}
}
step++
}
return step
}
var BFS = function(start) {
// core data structure
var q = [];
// avoid stepping back
var visited = new Set();
// add the start point to the queue
q.push(start);
visited.add(start);
// record the number of steps in the search
var step = 0;
while (q.length !== 0) {
var sz = q.length;
// spread all nodes in the current queue one step in all directions
for (var i = 0; i < sz; i++) {
var cur = q.shift();
console.log("从 " + start + " 到 " + cur + " 的最短距离是 " + step);
// add the adjacent nodes of cur to the queue
for (var j = 0; j < cur.adj().length; j++) {
var x = cur.adj()[j];
if (!visited.has(x)) {
q.push(x);
visited.add(x);
}
}
}
step++;
}
};
如果对 BFS 算法不熟悉,可以看前文 BFS 算法框架,这里只是为了让你做个对比,所谓 BFS 算法,就是把算法问题抽象成一幅「无权图」,然后继续玩二叉树层级遍历那一套罢了。
注意,我们的 BFS 算法框架也是 while
循环嵌套 for
循环的形式,也用了一个 step
变量记录 for
循环执行的次数,无非就是多用了一个 visited
集合记录走过的节点,防止走回头路罢了。
为什么这样呢?
所谓「无权图」,与其说每条「边」没有权重,不如说每条「边」的权重都是 1,从起点 start
到任意一个节点之间的路径权重就是它们之间「边」的条数,那可不就是 step
变量记录的值么?
再加上 BFS 算法利用 for
循环一层一层向外扩散的逻辑和 visited
集合防止走回头路的逻辑,当你每次从队列中拿出节点 cur
的时候,从 start
到 cur
的最短权重就是 step
记录的步数。
但是,到了「加权图」的场景,事情就没有这么简单了,因为你不能默认每条边的「权重」都是 1 了,这个权重可以是任意正数(Dijkstra 算法要求不能存在负权重边),比如下图的例子:
如果沿用 BFS 算法中的 step
变量记录「步数」,显然红色路径一步就可以走到终点,但是这一步的权重很大;正确的最小权重路径应该是绿色的路径,虽然需要走很多步,但是路径权重依然很小。
其实 Dijkstra 和 BFS 算法差不多,不过在讲解 Dijkstra 算法框架之前,我们首先需要对之前的框架进行如下改造:
想办法去掉 while
循环里面的 for
循环。
有了刚才的铺垫,这个不难理解,刚才说 for
循环是干什么用的来着?
是为了让二叉树一层一层往下遍历,让 BFS 算法一步一步向外扩散,因为这个层数 depth
,或者这个步数 step
,在之前的场景中有用。
但现在我们想解决「加权图」中的最短路径问题,「步数」已经没有参考意义了,「路径的权重之和」才有意义,所以这个 for
循环可以被去掉。
怎么去掉?就拿二叉树的层级遍历来说,其实你可以直接去掉 for
循环相关的代码:
// Input the root node of a binary tree, traverse all nodes of this binary tree
void levelTraverse(TreeNode root) {
if (root == null) return 0;
Queue<TreeNode> q = new LinkedList<>();
q.offer(root);
// Traverse each node of the binary tree
while (!q.isEmpty()) {
TreeNode cur = q.poll();
printf("我不知道节点 %s 在第几层", cur);
// I don't know which level the node %s is at
// Put child nodes into the queue
if (cur.left != null) {
q.offer(cur.left);
}
if (cur.right != null) {
q.offer(cur.right);
}
}
}
// Input the root node of a binary tree and traverse all nodes of the tree
void levelTraverse(TreeNode* root) {
if (root == nullptr) return 0;
queue<TreeNode*> q;
q.push(root);
// Traverse every node of the binary tree
while (!q.empty()) {
TreeNode* cur = q.front();
printf("我不知道节点 %s 在第几层", cur);
// Put the child nodes into the queue
if (cur->left != nullptr) {
q.push(cur->left);
}
if (cur->right != nullptr) {
q.push(cur->right);
}
q.pop();
}
}
# Input the root node of a binary tree, traverse all nodes of this binary tree
def levelTraverse(root: TreeNode) -> None:
if not root:
return
q = []
q.append(root)
# Traverse each node of the binary tree
while q:
cur = q.pop(0)
print(f"我不知道节点 {cur} 在第几层")
# I don't know which level node {cur} is at
# Put the child nodes into the queue
if cur.left:
q.append(cur.left)
if cur.right:
q.append(cur.right)
// Input the root node of a binary tree, traverse all nodes of this binary tree
func levelTraverse(root *TreeNode) {
if root == nil {
return
}
q := []*TreeNode{root}
// Traverse each node of the binary tree
for len(q) > 0 {
cur := q[0]
q = q[1:]
printf("我不知道节点 %s 在第几层", cur)
// I don't know which level node %s is on
// Put child nodes into the queue
if cur.left != nil {
q = append(q, cur.left)
}
if cur.right != nil {
q = append(q, cur.right)
}
}
}
function levelTraverse(root) {
if (root == null) return 0;
var q = [];
q.push(root);
while (q.length > 0) {
var cur = q.shift();
console.log("我不知道节点 " + cur + " 在第几层");
if (cur.left != null) {
q.push(cur.left);
}
if (cur.right != null) {
q.push(cur.right);
}
}
}
但问题是,没有 for
循环,你也没办法维护 depth
变量了。
如果你想同时维护 depth
变量,让每个节点 cur
知道自己在第几层,可以想其他办法,比如新建一个 State
类,记录每个节点所在的层数:
class State {
// record the depth of the node
int depth;
TreeNode node;
State(TreeNode node, int depth) {
this.depth = depth;
this.node = node;
}
}
// input the root node of a binary tree and traverse all nodes of this binary tree
void levelTraverse(TreeNode root) {
if (root == null) return 0;
Queue<State> q = new LinkedList<>();
q.offer(new State(root, 1));
// traverse each node of the binary tree
while (!q.isEmpty()) {
State cur = q.poll();
TreeNode cur_node = cur.node;
int cur_depth = cur.depth;
printf("节点 %s 在第 %s 层", cur_node, cur_depth);
// put the child nodes into the queue
if (cur_node.left != null) {
q.offer(new State(cur_node.left, cur_depth + 1));
}
if (cur_node.right != null) {
q.offer(new State(cur_node.right, cur_depth + 1));
}
}
}
struct State {
// record the depth of the node
int depth;
TreeNode* node;
State(TreeNode* node, int depth) {
this -> depth = depth;
this -> node = node;
}
};
// given the root of a binary tree, traverse all nodes of the tree
void levelTraverse(TreeNode* root) {
if (root == nullptr) return;
queue<State> q;
q.push(State(root, 1));
// traverse every node of the binary tree
while (!q.empty()) {
State cur = q.front();
q.pop();
TreeNode* cur_node = cur.node;
int cur_depth = cur.depth;
printf("节点 %s 在第 %s 层", cur_node, cur_depth);
// put the child nodes into the queue
if (cur_node -> left != nullptr) {
q.push(State(cur_node -> left, cur_depth + 1));
}
if (cur_node -> right != nullptr) {
q.push(State(cur_node -> right, cur_depth + 1));
}
}
}
class State:
def __init__(self, node: TreeNode, depth: int):
self.depth = depth
self.node = node
def levelTraverse(root: TreeNode) -> None:
if not root:
return 0
q = []
q.append(State(root, 1))
while q:
cur = q.pop(0)
cur_node = cur.node
cur_depth = cur.depth
print(f"节点 {cur_node} 在第 {cur_depth} 层")
if cur_node.left:
q.append(State(cur_node.left, cur_depth + 1))
if cur_node.right:
q.append(State(cur_node.right, cur_depth + 1))
type State struct {
depth int
node *TreeNode
}
func levelTraverse(root *TreeNode) {
if root == nil {
return 0
}
q := make([]State, 0)
q = append(q, State{root, 1})
for len(q) > 0 {
cur := q[0]
q = q[1:]
cur_node := cur.node
cur_depth := cur.depth
printf("节点 %s 在第 %s 层", cur_node, cur_depth)
if cur_node.left != nil {
q = append(q, State{cur_node.left, cur_depth + 1})
}
if cur_node.right != nil {
q = append(q, State{cur_node.right, cur_depth + 1})
}
}
}
var State = function(node, depth) {
// record the depth of the node
this.depth = depth;
this.node = node;
};
// input the root node of a binary tree, traverse all nodes of this tree
function levelTraverse(root) {
if (root === null) return 0;
var q = [];
q.push(new State(root, 1));
// traverse each node of the binary tree
while (q.length > 0) {
var cur = q.shift();
var cur_node = cur.node;
var cur_depth = cur.depth;
console.log("节点 " + cur_node + " 在第 " + cur_depth + " 层");
// put child nodes into the queue
if (cur_node.left != null) {
q.push(new State(cur_node.left, cur_depth + 1));
}
if (cur_node.right != null) {
q.push(new State(cur_node.right, cur_depth + 1));
}
}
}
这样,我们就可以不使用 for
循环也确切地知道每个二叉树节点的深度了。
如果你能够理解上面这段代码,我们就可以来看 Dijkstra 算法的代码框架了。