K-Medoids Clustering in Python
In this tutorial, we will learn how to perform K-Medoids clustering in Python.
First, let’s import the required libraries.
import numpy as np
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
Now we will write a function for K-Medoids clustering. Initially, we select k random points from the given data as the medoids and assign each data point to its nearest medoid using the Euclidean distance metric. Then, for each cluster, we compute the cost of every point in it, that is, the sum of its distances to all other points in the cluster, and make the point with the minimum cost the new medoid. We repeat these two steps until the medoids no longer change (or the maximum number of iterations is reached) and return those points as the final medoids.
def kmedoids(X, k, max_iterations=100):
    n_samples = X.shape[0]
    medoid_indices = np.random.choice(n_samples, k, replace=False)
    medoids = X[medoid_indices]
    for _ in range(max_iterations):
        # Compute distances between all points and medoids
        distances = cdist(X, medoids, metric='euclidean')
        # Assign each point to the nearest medoid
        labels = np.argmin(distances, axis=1)
        new_medoids = np.copy(medoids)
        for i in range(k):
            cluster_points = X[labels == i]
            if len(cluster_points) == 0:
                continue
            # Compute the cost for each point in the cluster to be a medoid
            cost = cdist(cluster_points, cluster_points, metric='euclidean').sum(axis=1)
            # Choose the point with the minimum cost as the new medoid
            new_medoids[i] = cluster_points[np.argmin(cost)]
        # If there is no change in medoids then break and return
        if np.all(new_medoids == medoids):
            break
        medoids = new_medoids
    return labels, medoids
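To see the medoid update step in isolation, here is a minimal sketch on a toy cluster. The five points are made-up values chosen only for illustration; the cost computation is the same cdist call used inside the function.

```python
import numpy as np
from scipy.spatial.distance import cdist

# Toy cluster (made-up values): four points close together and one outlier
points = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [10.0, 0.0]])

# cost[i] = sum of Euclidean distances from points[i] to every point in the cluster
cost = cdist(points, points, metric='euclidean').sum(axis=1)

# The medoid is the cluster member with the smallest total distance
medoid = points[np.argmin(cost)]

# For comparison: the mean, which a centroid-based method would use
mean = points.mean(axis=0)

print("costs: ", cost)
print("medoid:", medoid)
print("mean:  ", mean)
```

Here the medoid is the existing point [2, 0], while the mean is pulled toward the outlier and does not coincide with any data point. A medoid is always one of the data points.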
Now, we will generate some sample data points with make_blobs.
X, _ = make_blobs(n_samples=1000, centers=5, cluster_std=0.60, random_state=0)
Now, we will perform clustering by calling the function.
k = 5
labels, medoids = kmedoids(X, k)
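Because the initial medoids are drawn with np.random.choice, repeated runs can start from different points and may end with different medoids. One possible way to make the draw repeatable is to seed NumPy's global random generator first. The sketch below assumes that the global seed is acceptable for your use; the seed value 42 is arbitrary.

```python
import numpy as np

# Seed the global generator, then draw 5 distinct indices out of 1000
np.random.seed(42)  # arbitrary seed value, chosen only for illustration
first = np.random.choice(1000, 5, replace=False)

# Re-seeding with the same value reproduces the same draw
np.random.seed(42)
second = np.random.choice(1000, 5, replace=False)

same = np.array_equal(first, second)
print(same)
```

Calling np.random.seed once before kmedoids(X, k) should have the same effect on the function above, since it draws its initial medoids from the same global generator.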
Finally, we plot the clusters and their medoids.
plt.figure(figsize=(6, 4))
scatter = plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')
plt.scatter(medoids[:, 0], medoids[:, 1], c='red', s=50, marker='*', label='Medoids')
plt.title('K-Medoids Clustering')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.colorbar(scatter, label='Cluster')
plt.show()
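Besides the plot, a single number can help compare runs: the total distance from each point to its assigned medoid. The helper below, clustering_cost, is a hypothetical name introduced here and is not part of the tutorial's code; it is shown on a tiny made-up dataset so that the result can be checked by hand.

```python
import numpy as np

def clustering_cost(X, labels, medoids):
    """Sum of Euclidean distances from each point to its assigned medoid."""
    # medoids[labels] picks, for every point, the medoid of its cluster
    return float(np.linalg.norm(X - medoids[labels], axis=1).sum())

# Tiny made-up example: two clusters of two points each
X_demo = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 7.0]])
medoids_demo = np.array([[0.0, 0.0], [5.0, 5.0]])
labels_demo = np.array([0, 0, 1, 1])

# Distances to the assigned medoids are 0, 1, 0 and 2
demo_cost = clustering_cost(X_demo, labels_demo, medoids_demo)
print(demo_cost)
```

With the variables from this tutorial, the call would be clustering_cost(X, labels, medoids). A lower value means the points sit closer to their medoids.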