Demonstrate Prim’s Minimum Spanning Tree Algorithm in Python

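In this tutorial, you will learn Prim’s minimum spanning tree algorithm in Python. A spanning tree of a connected, undirected graph is a subgraph that contains all of its vertices and is a tree: it is connected and has no cycles, so for V vertices it uses exactly V - 1 edges. A minimum spanning tree (MST) is a spanning tree whose total edge weight is as small as possible.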
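As a quick illustration of the definition, the following sketch checks whether a list of edges forms a spanning tree on the vertices 0 to n - 1. The function name `is_spanning_tree` and the edge-list input are our own choices for illustration, not part of the tutorial's code. It relies on a standard fact: a graph with n vertices and exactly n - 1 edges is a tree if and only if it is connected.

```python
def is_spanning_tree(num_vertices, edges):
    """Return True if `edges` (pairs (u, v) on vertices 0..num_vertices-1)
    form a spanning tree of those vertices (hypothetical helper)."""
    # A tree on n vertices has exactly n - 1 edges ...
    if len(edges) != num_vertices - 1:
        return False
    adjacency = {v: [] for v in range(num_vertices)}
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)
    # ... and with n - 1 edges, being connected is enough to rule out cycles.
    # Depth-first search from vertex 0 collects every reachable vertex.
    seen = {0}
    stack = [0]
    while stack:
        u = stack.pop()
        for v in adjacency[u]:
            if v not in seen:
                seen.add(v)
                stack.append(v)
    return len(seen) == num_vertices


print(is_spanning_tree(5, [(0, 2), (1, 2), (1, 3), (3, 4)]))  # True
print(is_spanning_tree(5, [(0, 2), (1, 2), (1, 3)]))          # False
```

For the example graph used later in this tutorial, the edges (0, 2), (1, 2), (1, 3) and (3, 4) pass the check, while dropping any one of them fails it.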

Prim’s Minimum Spanning Tree Algorithm

Prim’s algorithm is a greedy algorithm: at every step it makes the choice that looks best at that moment, and for minimum spanning trees this locally optimal choice is guaranteed to produce an optimal tree. It starts with a tree that contains only an arbitrarily chosen starting vertex and maintains two lists of vertices: one stores the vertices already included in the minimum spanning tree, and the other stores the vertices not yet included. At every step, it considers the edges that connect a vertex in the first list to a vertex in the second list and chooses the one with the minimum cost (weight). The vertex at the far end of that edge is then moved from the second list to the first. This process continues until the second list is empty, i.e. all vertices have been included in the minimum spanning tree.
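The two-list description can be written down almost directly. This is a minimal sketch, assuming an undirected graph given as a list of `(u, v, weight)` tuples; the name `prim_two_lists` and this input format are hypothetical, and Python sets stand in for the two lists. It rescans every edge at each step, so it takes O(V * E) time: fine for illustration, but slower than the key-based version in the tutorial's code.

```python
import math


def prim_two_lists(num_vertices, edges, start=0):
    """Prim's algorithm following the two-list description.

    edges: list of (u, v, weight) tuples for an undirected graph whose
    vertices are 0 .. num_vertices - 1 (hypothetical input format).
    Returns the chosen edges in the order they were added.
    """
    in_tree = {start}                                  # first list
    not_in_tree = set(range(num_vertices)) - in_tree   # second list
    mst_edges = []
    while not_in_tree:
        best_edge, best_weight = None, math.inf
        # Only edges with exactly one endpoint in the tree cross the two lists
        for u, v, w in edges:
            if (u in in_tree) != (v in in_tree) and w < best_weight:
                best_edge, best_weight = (u, v, w), w
        if best_edge is None:
            raise ValueError("graph is not connected")
        u, v, _ = best_edge
        new_vertex = v if u in in_tree else u
        mst_edges.append(best_edge)
        # Move the newly reached vertex from the second list to the first
        in_tree.add(new_vertex)
        not_in_tree.remove(new_vertex)
    return mst_edges


edges = [(0, 1, 1), (0, 2, 4), (1, 2, 2), (1, 3, 6), (2, 3, 3)]
print(prim_two_lists(4, edges))  # [(0, 1, 1), (1, 2, 2), (2, 3, 3)]
```

The small four-vertex graph in the usage line is a made-up example; at each step the cheapest crossing edge is taken, giving a tree of total weight 6.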

The main idea behind a minimum spanning tree is to connect all the vertices of the graph into a single tree with the minimum possible total cost or weight.
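To make "minimum total cost" concrete, the sketch below finds the weight of a minimum spanning tree by brute force: it tries every set of V - 1 edges, keeps those without a cycle (with V - 1 edges these are exactly the spanning trees), and returns the smallest total weight. It is exponential and only meant for checking tiny hand-made examples; the function names and the `(u, v, weight)` edge format are assumptions for illustration.

```python
from itertools import combinations


def has_no_cycle(num_vertices, edges):
    # Union-find: an edge whose endpoints already share a root would close a cycle
    root = list(range(num_vertices))

    def find(x):
        while root[x] != x:
            x = root[x]
        return x

    for u, v, _ in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False
        root[ru] = rv
    return True


def brute_force_mst_weight(num_vertices, edges):
    """Smallest total weight over all spanning trees, or None if there is none."""
    best = None
    for subset in combinations(edges, num_vertices - 1):
        # V - 1 edges without a cycle always connect all V vertices
        if has_no_cycle(num_vertices, subset):
            total = sum(w for _, _, w in subset)
            if best is None or total < best:
                best = total
    return best


edges = [(0, 1, 1), (0, 2, 4), (1, 2, 2), (1, 3, 6), (2, 3, 3)]
print(brute_force_mst_weight(4, edges))  # 6
```

It returns None when the graph is not connected, since no spanning tree exists in that case. Comparing its result with Prim's algorithm on small graphs is a simple way to sanity-check an implementation.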

The steps of Prim’s algorithm are:

  1. Define a list present that stores all vertices already included in the minimum spanning tree. Initially it is empty.
  2. Assign a key to every vertex and initialize all keys to infinity.
  3. Let the starting vertex be m.
  4. Set the key of m to 0, so that m is picked first.
  5. Repeat until present contains all vertices:
    1. Pick the vertex n that is not in present and has the minimum key.
    2. Include n in present.
    3. For every vertex v adjacent to n that is not in present, if the weight of edge n-v is smaller than the key of v, set the key of v to that weight.
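Step 5.1 scans every vertex to find the minimum key. A common alternative, swapped in here and not taken from the tutorial's code, keeps candidate vertices in a binary heap (Python's `heapq` module) and skips outdated entries when they are popped ("lazy deletion"). This sketch assumes a hypothetical adjacency-list input: a dict mapping each vertex to a list of `(neighbor, weight)` pairs, with every undirected edge listed from both ends. The comments point back to the numbered steps.

```python
import heapq


def prim_heap(adjacency, start=0):
    """Lazy Prim's algorithm with a binary heap.

    adjacency: dict mapping each vertex to a list of (neighbor, weight) pairs;
    every undirected edge is listed from both ends (hypothetical format).
    Returns (total_weight, parent); only start's connected component is covered.
    """
    key = {v: float("inf") for v in adjacency}   # step 2: all keys start at infinity
    key[start] = 0                               # step 4: key of the start vertex is 0
    parent = {start: None}
    present = set()                              # step 1: vertices in the tree
    heap = [(0, start)]
    total = 0
    while heap:
        k, n = heapq.heappop(heap)               # step 5.1: smallest key first
        if n in present:
            continue                             # outdated entry, n was already added
        present.add(n)                           # step 5.2
        total += k
        for v, w in adjacency[n]:                # step 5.3: lower neighbours' keys
            if v not in present and w < key[v]:
                key[v] = w
                parent[v] = n
                heapq.heappush(heap, (w, v))
    return total, parent


adjacency = {0: [(2, 4)], 1: [(2, 5), (3, 3)], 2: [(0, 4), (1, 5)],
             3: [(1, 3), (4, 2)], 4: [(3, 2)]}
print(prim_heap(adjacency))  # (14, {0: None, 2: 0, 1: 2, 3: 1, 4: 3})
```

Each edge causes at most one push, so this variant runs in O(E log E) time instead of the O(V^2) of the matrix scan; on dense graphs the simple scan remains competitive.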

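The key of a vertex holds the weight of the cheapest edge found so far that connects it to the minimum spanning tree; whether a vertex is already in the tree is tracked separately by present (mstSet in the code). Let’s have a look at the code for the minimum spanning tree using Prim’s algorithm.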
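To see that a key is an edge weight rather than a membership flag, the sketch below reruns the array-based algorithm on the tutorial's example matrix and records the key list after each vertex is added. The function name `prim_key_trace` is hypothetical, and it uses `math.inf` for infinity.

```python
import math


def prim_key_trace(matrix, start=0):
    """Array-based Prim's algorithm on an adjacency matrix (0 = no edge).

    Returns a list of (added_vertex, copy_of_key_list) pairs, one per step.
    """
    n = len(matrix)
    key = [math.inf] * n
    present = [False] * n
    key[start] = 0
    trace = []
    for _ in range(n):
        # The vertex outside the tree with the smallest key joins next
        u = min((v for v in range(n) if not present[v]), key=lambda v: key[v])
        present[u] = True
        for v in range(n):
            w = matrix[u][v]
            # Keys of vertices already in the tree are never touched again
            if w > 0 and not present[v] and w < key[v]:
                key[v] = w
        trace.append((u, list(key)))
    return trace


example = [[0, 0, 4, 0, 0],
           [0, 0, 5, 3, 0],
           [4, 5, 0, 0, 0],
           [0, 3, 0, 0, 2],
           [0, 0, 0, 2, 0]]
for added, keys in prim_key_trace(example):
    print(added, keys)
```

On this graph the vertices are added in the order 0, 2, 1, 3, 4. A key stops changing once its vertex is in present, so at the end the keys are exactly the weights of the tree edges (0 for the starting vertex), and their sum, 14, is the total weight of the minimum spanning tree.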

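The Python code below represents the graph as an adjacency matrix, where graph[u][v] is the weight of edge u-v and 0 means there is no edge. Vertex 0 is used as the starting vertex.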

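import sys

class Graph:
    def __init__(self, vertices):
        self.V = vertices
        # Adjacency matrix: graph[u][v] is the weight of edge u-v, 0 means no edge
        self.graph = [[0 for column in range(vertices)]
                      for row in range(vertices)]

    def printTree(self, parent):
        # Print each MST edge as "parent - vertex" with its weight
        print("Edge \tWeight")
        for i in range(1, self.V):
            print(parent[i], "-", i, "\t", self.graph[i][parent[i]])

    def min_Key(self, key, mstSet):
        # Return the vertex not yet in the MST that has the smallest key
        min_value = sys.maxsize
        min_index = -1
        for v in range(self.V):
            if key[v] < min_value and not mstSet[v]:
                min_value = key[v]
                min_index = v
        return min_index

    def prim(self):
        # key[v]: weight of the cheapest known edge joining v to the tree
        key = [sys.maxsize] * self.V
        parent = [None] * self.V
        key[0] = 0          # vertex 0 is the starting vertex
        mstSet = [False] * self.V
        parent[0] = -1      # the starting vertex has no parent
        for _ in range(self.V):
            u = self.min_Key(key, mstSet)
            if u == -1:
                raise ValueError("graph is not connected")
            mstSet[u] = True
            # Update the keys of u's neighbours that are not yet in the MST
            for v in range(self.V):
                if self.graph[u][v] > 0 and not mstSet[v] and key[v] > self.graph[u][v]:
                    key[v] = self.graph[u][v]
                    parent[v] = u
        self.printTree(parent)

g = Graph(5)
g.graph = [[0, 0, 4, 0, 0],
           [0, 0, 5, 3, 0],
           [4, 5, 0, 0, 0],
           [0, 3, 0, 0, 2],
           [0, 0, 0, 2, 0]]
g.prim()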
