K-Medoids Clustering in Python
In this tutorial, we will learn how to perform k-medoids clustering in Python.
First, let’s import the required libraries.
```python
import numpy as np
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
```
Now we will write a function for k-medoids clustering. Initially, we select k random data points as the medoids and assign every data point to its nearest medoid using the Euclidean distance. Then, within each cluster, we calculate the cost of each point (the sum of its distances to all other points in that cluster) and make the point with the minimum cost the new medoid. We repeat these two steps until the medoids no longer change (or a maximum number of iterations is reached), and the points selected at that stage are the final medoids.
```python
def kmedoids(X, k, max_iterations=100):
    """Cluster the rows of X into k groups and return (labels, medoids)."""
    n_samples = X.shape[0]
    # Pick k distinct data points at random as the initial medoids
    medoid_indices = np.random.choice(n_samples, k, replace=False)
    medoids = X[medoid_indices]

    for _ in range(max_iterations):
        # Compute distances between all points and the current medoids
        distances = cdist(X, medoids, metric='euclidean')
        # Assign each point to its nearest medoid
        labels = np.argmin(distances, axis=1)

        new_medoids = np.copy(medoids)
        for i in range(k):
            cluster_points = X[labels == i]
            if len(cluster_points) == 0:
                continue  # keep the old medoid if the cluster is empty
            # Cost of each candidate = sum of its distances to all points in the cluster
            cost = cdist(cluster_points, cluster_points, metric='euclidean').sum(axis=1)
            # The point with the minimum cost becomes the new medoid
            new_medoids[i] = cluster_points[np.argmin(cost)]

        # If the medoids did not change, the algorithm has converged
        if np.array_equal(new_medoids, medoids):
            break
        medoids = new_medoids

    # Reassign points to the final medoids so that labels and medoids always
    # agree, even when the loop stops because max_iterations was reached
    labels = np.argmin(cdist(X, medoids, metric='euclidean'), axis=1)
    return labels, medoids
```
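To see what the medoid update inside kmedoids does, the snippet below applies the same cost computation to a tiny, made-up cluster of five points on a line, one of which is an outlier. The data is purely illustrative. It also shows why a medoid differs from a k-means centroid: the medoid is always an actual data point and is pulled less by the outlier than the mean is.

```python
import numpy as np
from scipy.spatial.distance import cdist

# A tiny illustrative cluster: four nearby points plus one outlier at x = 10
cluster_points = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [10.0, 0.0]])

# Cost of each candidate medoid = sum of its distances to every point in the cluster
cost = cdist(cluster_points, cluster_points, metric='euclidean').sum(axis=1)
medoid = cluster_points[np.argmin(cost)]
mean = cluster_points.mean(axis=0)

print(cost)    # [16. 13. 12. 13. 34.]
print(medoid)  # [2. 0.]
print(mean)    # the mean is (3.2, 0), dragged toward the outlier
```

The point (2, 0) has the smallest total distance (12), so it becomes the medoid, while the mean lands at (3.2, 0), which is not one of the data points.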
Next, we generate some sample data: 1,000 two-dimensional points grouped around 5 centers, using make_blobs from scikit-learn.
```python
X, _ = make_blobs(n_samples=1000, centers=5, cluster_std=0.60, random_state=0)
```
Now, we perform the clustering by calling the function with k = 5.
```python
k = 5
labels, medoids = kmedoids(X, k)
```
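Because the initial medoids are chosen at random, two runs can converge to different, locally optimal solutions (for example, two medoids ending up in the same blob). One common remedy, sketched here as an assumption rather than as part of the original tutorial, is to run the algorithm several times and keep the run with the lowest total cost, i.e. the sum of distances from each point to its assigned medoid. The names `kmedoids_seeded`, `total_cost`, and `best_of_n` are hypothetical helpers introduced for this sketch; `kmedoids_seeded` is a copy of the tutorial's function that draws its initial medoids from a NumPy `Generator` so the restarts are reproducible.

```python
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.datasets import make_blobs


def kmedoids_seeded(X, k, rng, max_iterations=100):
    """Same alternating k-medoids as the tutorial, with initial medoids drawn from rng."""
    medoids = X[rng.choice(X.shape[0], k, replace=False)]
    for _ in range(max_iterations):
        labels = np.argmin(cdist(X, medoids), axis=1)  # assignment step
        new_medoids = np.copy(medoids)
        for i in range(k):
            cluster_points = X[labels == i]
            if len(cluster_points) == 0:
                continue
            # Update step: the point with the smallest summed distance becomes the medoid
            cost = cdist(cluster_points, cluster_points).sum(axis=1)
            new_medoids[i] = cluster_points[np.argmin(cost)]
        if np.array_equal(new_medoids, medoids):
            break
        medoids = new_medoids
    labels = np.argmin(cdist(X, medoids), axis=1)
    return labels, medoids


def total_cost(X, labels, medoids):
    """Sum of Euclidean distances from each point to the medoid of its cluster."""
    return np.linalg.norm(X - medoids[labels], axis=1).sum()


def best_of_n(X, k, n_init=10, seed=0):
    """Run k-medoids n_init times from different random starts; keep the lowest-cost run."""
    rng = np.random.default_rng(seed)
    best = None
    for _ in range(n_init):
        labels, medoids = kmedoids_seeded(X, k, rng)
        cost = total_cost(X, labels, medoids)
        if best is None or cost < best[0]:
            best = (cost, labels, medoids)
    return best


X, _ = make_blobs(n_samples=1000, centers=5, cluster_std=0.60, random_state=0)
cost, labels, medoids = best_of_n(X, k=5, n_init=10, seed=0)
print(f"best total cost: {cost:.2f}")
```

More restarts make a poor local optimum less likely at the price of proportionally more computation; the values `n_init=10` and `seed=0` are arbitrary choices for illustration.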
Finally, we plot the result: each point is colored by its cluster label, and the medoids are marked with red stars.
```python
plt.figure(figsize=(6, 4))
# Color each point by its cluster label
scatter = plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')
# Mark the medoids with red stars
plt.scatter(medoids[:, 0], medoids[:, 1], c='red', s=50, marker='*', label='Medoids')
plt.title('K-Medoids Clustering')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.colorbar(scatter, label='Cluster')
plt.show()
```