Find the Minimum Spanning Tree Using Kruskal’s Algorithm

problem-solving
Published October 29, 2024

Finding the minimum spanning tree (MST) of a graph is a fundamental problem in graph theory, with applications such as network design and clustering. Kruskal’s algorithm solves it efficiently with a simple greedy strategy. This post walks through the algorithm and demonstrates an implementation in Python.

Understanding Minimum Spanning Trees

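A spanning tree of a connected, undirected graph is a subgraph that connects all vertices without any cycles. For a graph with V vertices, every spanning tree has exactly V-1 edges. When each edge has an associated weight, a minimum spanning tree is a spanning tree whose edges have the smallest possible total weight. If several edges share the same weight, a graph can have more than one MST.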
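To make the definition concrete, here is a minimal brute-force sketch that applies it literally: it tries every set of V-1 edges, keeps the sets that connect all vertices, and returns the lightest one. The graph, the (u, v, weight) edge format, and the helper names `is_spanning_tree` and `brute_force_mst` are assumptions chosen for illustration. This approach is only practical for tiny graphs.

```python
from itertools import combinations

# Hypothetical 4-vertex example graph: (u, v, weight) triples.
EDGES = [(0, 1, 4), (0, 2, 1), (1, 2, 2), (1, 3, 5), (2, 3, 8)]
NUM_VERTICES = 4


def is_spanning_tree(edge_subset, num_vertices):
    """True if the edges connect every vertex using exactly V-1 edges."""
    if len(edge_subset) != num_vertices - 1:
        return False
    # Flood-fill from vertex 0 over the chosen edges.
    neighbours = {v: [] for v in range(num_vertices)}
    for u, v, _ in edge_subset:
        neighbours[u].append(v)
        neighbours[v].append(u)
    seen, stack = {0}, [0]
    while stack:
        for nxt in neighbours[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    # V-1 edges that reach all V vertices cannot contain a cycle.
    return len(seen) == num_vertices


def brute_force_mst(edges, num_vertices):
    """Try every (V-1)-edge subset and keep the lightest spanning tree."""
    candidates = [
        subset
        for subset in combinations(edges, num_vertices - 1)
        if is_spanning_tree(subset, num_vertices)
    ]
    return min(candidates, key=lambda subset: sum(w for _, _, w in subset))


best = brute_force_mst(EDGES, NUM_VERTICES)
print(sorted(best), sum(w for _, _, w in best))
# prints: [(0, 2, 1), (1, 2, 2), (1, 3, 5)] 8
```

The three cheapest edges (weights 1, 2, and 4) form a triangle on vertices 0, 1, and 2, so they are not a spanning tree. The lightest valid choice swaps the weight-4 edge for the weight-5 edge that reaches vertex 3.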

Kruskal’s Algorithm Explained

Kruskal’s algorithm employs a greedy approach. It iteratively adds edges to the MST, always selecting the edge with the smallest weight that doesn’t create a cycle. This continues until all vertices are connected.

Here’s a step-by-step breakdown:

  1. Sort Edges: Sort all edges in the graph in ascending order of their weights.
  2. Initialize MST: Start with an empty minimum spanning tree (MST).
  3. Iterate through Edges: For each edge in the sorted list:
    • Check if adding the edge creates a cycle in the MST. This is typically done using a disjoint-set data structure (also known as a union-find data structure).
    • If adding the edge doesn’t create a cycle, add it to the MST.
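  4. Termination: The algorithm terminates when V-1 edges have been added to the MST, where V is the number of vertices in the graph. If the graph is not connected, the sorted list runs out first and the result is a minimum spanning forest rather than a single tree.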
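The steps above can be traced on a small graph. The sketch below records an accept or skip decision for every edge it examines. To keep the trace easy to follow, it detects cycles with a plain list of component labels instead of a disjoint-set structure, which is simpler but slower. The function name `kruskal_trace` and the example edges are assumptions for illustration.

```python
def kruskal_trace(num_vertices, edges):
    """Run Kruskal's steps, recording an accept/skip decision per edge.

    Cycle detection uses a plain list of component labels, which is
    simple to follow but slower than a disjoint-set structure.
    """
    component = list(range(num_vertices))  # component[v] = label of v's tree
    mst, log = [], []
    for u, v, weight in sorted(edges, key=lambda e: e[2]):  # step 1: sort
        if component[u] == component[v]:
            log.append((u, v, weight, "skip (would form a cycle)"))
            continue
        # Merge the two components by relabelling one of them.
        old, new = component[v], component[u]
        component = [new if c == old else c for c in component]
        mst.append((u, v, weight))
        log.append((u, v, weight, "add"))
        if len(mst) == num_vertices - 1:  # step 4: tree is complete
            break
    return mst, log


# Hypothetical example: 4 vertices, edges as (u, v, weight).
edges = [(0, 1, 4), (0, 2, 1), (1, 2, 2), (1, 3, 5), (2, 3, 8)]
mst, log = kruskal_trace(4, edges)
for u, v, weight, decision in log:
    print(f"edge ({u}, {v}) weight {weight}: {decision}")
```

On this input the trace adds (0, 2) and (1, 2), skips (0, 1) because vertices 0 and 1 are already connected through vertex 2, then adds (1, 3) and stops. The weight-8 edge is never examined.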

Python Implementation

We’ll use a disjoint-set data structure to efficiently check for cycles. Here’s the Python code:

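class DisjointSet:
    """Union-find over the integers 0..n-1 with path compression and union by rank."""

    def __init__(self, n):
        self.parent = list(range(n))  # every element starts as its own root
        self.rank = [0] * n           # upper bound on the height of each root's tree

    def find(self, i):
        """Return the root of i's set, compressing the path along the way."""
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        """Merge the sets containing i and j; return False if they were already merged."""
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Attach the shallower tree under the deeper one.
            if self.rank[root_i] < self.rank[root_j]:
                self.parent[root_i] = root_j
            elif self.rank[root_i] > self.rank[root_j]:
                self.parent[root_j] = root_i
            else:
                # Equal ranks: pick root_i as the new root and bump its rank.
                self.parent[root_j] = root_i
                self.rank[root_i] += 1
            return True
        return False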


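def kruskal_mst(graph):
    """Return the MST of an adjacency-list graph as a list of (u, v, weight) edges."""
    num_vertices = len(graph)
    edges = []
    for u in range(num_vertices):
        for v, weight in graph[u]:
            # An undirected edge listed under both endpoints is collected twice;
            # the second copy is rejected later because union() returns False.
            edges.append((weight, u, v))

    edges.sort()  # tuples compare by weight first
    mst = []
    dset = DisjointSet(num_vertices)
    for weight, u, v in edges:
        # union() succeeds only when u and v are in different components,
        # so an accepted edge can never close a cycle.
        if dset.union(u, v):
            mst.append((u, v, weight))
            if len(mst) == num_vertices - 1:
                break  # a spanning tree has exactly V-1 edges

    return mst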

# Example graph as an adjacency list: graph[u] holds (neighbor, weight) pairs,
# and each undirected edge appears in the lists of both endpoints.
graph = [
    [(1, 4), (2, 1)],
    [(0, 4), (2, 2), (3, 5)],
    [(0, 1), (1, 2), (3, 8)],
    [(1, 5), (2, 8)]
]

mst = kruskal_mst(graph)
print("Minimum Spanning Tree:")
for u, v, weight in mst:
    print(f"Edge: ({u}, {v}), Weight: {weight}")

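This code first defines a DisjointSet class that tracks which vertices are already connected. Its union method returns False when both endpoints belong to the same set, which is exactly the case where the edge would close a cycle. The kruskal_mst function gathers the edges from the adjacency list, sorts them by weight, and keeps each edge for which union succeeds. For the sample graph, it selects the edges (0, 2), (1, 2), and (1, 3), with weights 1, 2, and 5, for a total weight of 8.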
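The cycle check is easier to see in isolation. The sketch below uses a trimmed-down stand-in named `TinyDisjointSet` (a hypothetical name; it keeps path compression but drops the rank bookkeeping so the snippet runs on its own) and feeds it the three edges of a triangle.

```python
class TinyDisjointSet:
    """Trimmed-down stand-in for DisjointSet: path compression, no ranks."""

    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, i):
        # Walk up to the root, then point every visited node straight at it.
        root = i
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[i] != root:
            self.parent[i], i = root, self.parent[i]
        return root

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False  # same component: an edge i-j would close a cycle
        self.parent[root_j] = root_i
        return True


dset = TinyDisjointSet(4)
# Edges 0-2 and 1-2 join three vertices; edge 0-1 would then close a triangle.
results = [dset.union(0, 2), dset.union(1, 2), dset.union(0, 1)]
print(results)                       # prints: [True, True, False]
print(dset.find(0) == dset.find(1))  # prints: True
```

The third call returns False because vertices 0 and 1 already share a root, so Kruskal’s algorithm would discard that edge.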

Improving the Code (Optional)

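The implementation above runs in O(E log E) time for a graph with E edges. Sorting dominates the cost, since each union-find operation takes nearly constant time. For larger graphs or production code, a well-tested library is usually a better choice than a hand-written version. NetworkX, for example, provides a built-in minimum_spanning_tree function that uses Kruskal’s algorithm by default.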
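As a rough sketch of the library route, the snippet below builds a small graph in NetworkX and asks for its MST. It assumes the networkx package is installed, and the example edges are made up for illustration.

```python
import networkx as nx

# Hypothetical example graph, given as (u, v, weight) triples.
weighted_edges = [(0, 1, 4), (0, 2, 1), (1, 2, 2), (1, 3, 5), (2, 3, 8)]

G = nx.Graph()
G.add_weighted_edges_from(weighted_edges)  # weights are stored under "weight"

# algorithm="kruskal" is the default; it is spelled out here for clarity.
T = nx.minimum_spanning_tree(G, weight="weight", algorithm="kruskal")

# Normalise each edge to (smaller vertex, larger vertex, weight) before sorting.
mst_edges = sorted((min(u, v), max(u, v), w) for u, v, w in T.edges(data="weight"))
total_weight = sum(w for _, _, w in mst_edges)
print(mst_edges, total_weight)
```

The result is returned as a new graph object, so the usual NetworkX tools for traversal or drawing can be applied to it directly.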