Mean Shift Clustering using Sklearn

Last Updated : 18 Dec, 2023

Clustering is a fundamental technique in unsupervised machine learning, and one powerful algorithm for this task is Mean Shift clustering. Mean Shift groups similar data points into clusters based on their inherent structure, without prior knowledge of the number of clusters. This article explores the idea of Mean Shift clustering and shows how to apply it with the scikit-learn library in Python. We’ll cover key concepts like clustering, Kernel Density Estimation (KDE), and bandwidth, and offer step-by-step instructions for performing Mean Shift clustering with scikit-learn.

What is Mean-Shift?

Mean Shift is a clustering algorithm used to identify dense regions in a dataset and assign data points to their respective clusters. It is a non-parametric, density-based clustering technique, which means it does not require any prior assumptions about the number of clusters or their shapes. Instead, it discovers clusters based on the density of data points in the feature space.

  • Mean Shift is a mode-seeking algorithm, which means that it finds the modes (peaks) of the density distribution of the data. This is in contrast to centroid-based clustering algorithms, such as K-Means clustering, which find the centroids of the clusters.
  • Mean Shift works by iteratively refining the positions of the data points, moving them towards the modes of the density distribution until they converge.

Key Concepts of Mean-Shift Clustering

1- Kernel Density Estimation (KDE): Kernel Density Estimation (KDE) is a non-parametric statistical method used to estimate the probability density function (PDF) of a continuous random variable. It provides a smooth and continuous representation of the underlying data distribution. The basic idea behind KDE is to place a kernel (a smooth, continuous, and symmetric function, usually a Gaussian) on each data point and sum those kernels to estimate the PDF. The formula for KDE is:

f_{k}(u) = \frac{1}{nh^{d}}\sum_{i=1}^{n}K\left(\frac{u-u_{i}}{h}\right)

  • f_{k}(u) is the estimated PDF at point u.
  • n is the number of data points.
  • d is the dimensionality of the data.
  • u_{i} is the i-th data point.
  • K is a kernel function.
  • h is the bandwidth, a smoothing parameter.
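
To make the formula concrete, here is a minimal 1-D sketch of KDE with a Gaussian kernel (for d = 1 the normalizing factor reduces to 1/(nh)); the toy data and bandwidth are illustrative choices, not part of the original example:

Python3

import numpy as np

def gaussian_kernel(u):
    # Standard 1-D Gaussian kernel
    return np.exp(-0.5 * u ** 2) / np.sqrt(2 * np.pi)

def kde(u, data, h):
    # f(u) = 1/(n*h) * sum_i K((u - u_i) / h), the KDE formula for d = 1
    return gaussian_kernel((u - data) / h).sum() / (len(data) * h)

# Toy 1-D data drawn from two bumps
rng = np.random.default_rng(42)
data = np.concatenate([rng.normal(-2, 0.5, 100), rng.normal(2, 0.5, 100)])

print(kde(0.0, data, h=0.5))   # low density between the bumps
print(kde(2.0, data, h=0.5))   # high density near a mode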

2- Choosing the Right Bandwidth/Radius: The choice of bandwidth (h) in Mean Shift and KDE is crucial, as it significantly affects the smoothness and accuracy of the estimated PDF.

There are a number of different ways to select the bandwidth for Mean Shift clustering. Some common approaches include:

  • Scott’s Rule: This rule selects a bandwidth that is proportional to the standard deviation of the data.
  • Silverman’s Rule: This rule selects a bandwidth based on the smaller of the standard deviation and the interquartile range of the data.
  • Cross-validation: This approach involves training Mean Shift clustering with a range of different bandwidth values and evaluating the performance of the algorithm on a held-out test set. The bandwidth that results in the best performance on the test set is then selected.
  • Expert Knowledge: If you have domain-specific knowledge or prior information about your data, you can choose the bandwidth manually. Adjusting the bandwidth based on your understanding of the data’s characteristics can sometimes lead to better results.

The goal of bandwidth selection is to strike a balance between over-smoothing and under-smoothing (being sensitive to noise). The right value often depends on the specific characteristics of your data, and experimentation may be required to find the most appropriate bandwidth for your clustering or density estimation task.
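
In practice, scikit-learn ships a helper, estimate_bandwidth(), that derives a bandwidth from pairwise distances in the data. A short sketch (the quantile and sample count here are illustrative choices):

Python3

from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, random_state=42)

# estimate_bandwidth derives a bandwidth from pairwise distances;
# a smaller quantile gives a smaller (more local) bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=200)
print("Estimated bandwidth:", bandwidth)

mean_shift = MeanShift(bandwidth=bandwidth)
mean_shift.fit(X)
print("Clusters found:", len(mean_shift.cluster_centers_))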

3- Convergence: Convergence in Mean Shift occurs when the data points stop moving significantly. This means that the data points have reached the modes of the density distribution and are no longer moving towards higher density regions.

4- Bandwidth Kernel Function: The bandwidth kernel function is a function that is used to weight the data points when calculating the mean shift vector. It controls the size of the window around each data point that Mean Shift uses to calculate the mean shift vector. A larger bandwidth will result in fewer clusters, while a smaller bandwidth will result in more clusters.

5- Mean shift vector: The mean shift vector for a data point points in the direction of the highest density of data points around it. Mean Shift moves the data points in the direction of their mean shift vectors, which ultimately leads to the data points converging to the modes of the density distribution.
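
As a rough illustration, the sketch below computes a Gaussian-weighted mean shift vector for a single point; the helper name and toy data are assumptions made for this example:

Python3

import numpy as np

def mean_shift_vector(x, data, h):
    # Gaussian weights: nearby points count more than distant ones
    w = np.exp(-np.sum((data - x) ** 2, axis=1) / (2 * h ** 2))
    # The weighted mean of the neighbourhood minus the current position
    # points toward the region of highest local density
    return (w[:, None] * data).sum(axis=0) / w.sum() - x

rng = np.random.default_rng(0)
data = rng.normal(loc=[3.0, 3.0], scale=0.5, size=(100, 2))
x = np.array([1.0, 1.0])                  # a point away from the cluster
print(mean_shift_vector(x, data, h=1.0))  # points roughly toward (3, 3)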

How does Mean Shift work?

Mean Shift operates through a series of steps to identify clusters within a dataset:

  1. Kernel Density Estimation: It begins by estimating the PDF of the data using kernel density estimation. A kernel function is used to represent the distribution of data points in the feature space; a common choice is the Gaussian kernel. This step helps in identifying areas of high data density.
  2. Mean Shift Vector: For every data point, Mean Shift calculates a mean shift vector that points toward the mode (peak) of the estimated PDF. This vector represents the direction in which the data point should move to reach a region of higher density.
  3. Shifting Data Points: Data points are shifted in the direction of their respective mean shift vectors. This step is performed iteratively, and the data points gradually move toward the nearest modes (cluster centers) in the feature space.
  4. Convergence: The process continues until data points no longer move significantly, indicating convergence. A convergence threshold can be defined to control when to stop the iterations.
  5. Cluster Assignment: Once convergence is achieved, data points are assigned to their nearest modes, effectively forming clusters.

Mean Shift effectively identifies clusters by iteratively moving data points toward regions of higher density until they reach modes, which represent the cluster centers.

Algorithm

The Mean Shift algorithm follows a specific sequence of operations to move data points toward the modes of the estimated density. The steps involved are as follows:

  1. Initialization: Start with a set of data points.
  2. Bandwidth Selection: Choose an appropriate bandwidth h for the kernel. You can use techniques like Scott’s Rule or Silverman’s Rule, or estimate it using cross-validation.
  3. Mean Shift Computation: For each data point, compute the mean shift vector M(x).
  4. Shift Data Points: Update the data points by shifting them in the direction of M(x): x ← x + M(x).
  5. Convergence Check: Repeat steps 3 and 4 until the data points no longer move significantly. You can define a convergence threshold.
  6. Cluster Assignment: Once the data points have converged, assign each point to its nearest mode (cluster center).
  7. Result: The final clusters are determined based on the modes.

The Mean Shift algorithm iteratively refines the positions of the data points, moving them toward the modes of the estimated density, as the sketch below illustrates.
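
The following is a minimal from-scratch sketch of this loop, not the scikit-learn implementation; it uses a Gaussian kernel and a crude rounding step to group converged points into modes:

Python3

import numpy as np

def mean_shift(X, h=1.0, tol=1e-3, max_iter=300):
    # Iteratively shift every point to the Gaussian-weighted
    # mean of its neighbourhood (steps 3 and 4 of the algorithm)
    points = X.copy()
    for _ in range(max_iter):
        shifted = np.empty_like(points)
        for i, x in enumerate(points):
            w = np.exp(-np.sum((X - x) ** 2, axis=1) / (2 * h ** 2))
            shifted[i] = (w[:, None] * X).sum(axis=0) / w.sum()
        moved = np.abs(shifted - points).max()
        points = shifted
        if moved < tol:        # convergence check (step 5)
            break
    # Points that converged to (nearly) the same location share a mode;
    # rounding is a crude but simple way to group them (step 6)
    return np.unique(points.round(1), axis=0)

rng = np.random.default_rng(1)
X = np.vstack([rng.normal(c, 0.3, size=(50, 2))
               for c in ([0, 0], [4, 4], [0, 4])])
print(mean_shift(X, h=0.8))  # roughly three modes near the blob centres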

Example 1: Basic Mean Shift Clustering

The following code example illustrates how to use the MeanShift clustering algorithm from the scikit-learn library to cluster synthetic data.

The code first creates a dataset of 300 samples with 3 centers using the make_blobs() function from scikit-learn. Then, it fits the Mean Shift clustering algorithm to the data using the MeanShift() class from scikit-learn. Finally, it visualizes the results using the plt.scatter() function from Matplotlib. The visualization shows the data points colored by their cluster labels. The cluster centers are marked with red crosses.

Python

from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
 
# Create synthetic data
X, _ = make_blobs(n_samples=300, centers=3, random_state=42)
 
# Apply Mean Shift clustering
mean_shift = MeanShift()
mean_shift.fit(X)
 
# Visualize the results
plt.scatter(X[:, 0], X[:, 1], c=mean_shift.labels_, cmap='viridis', marker='o')
plt.scatter(mean_shift.cluster_centers_[:, 0], mean_shift.cluster_centers_[:, 1],
            c='red', marker='x', s=200, linewidths=3, label='Cluster Centers')
plt.title('Mean Shift Clustering')
plt.legend()
plt.show()


Output:

[Output figure: scatter plot titled “Mean Shift Clustering” showing the clustered points, with cluster centers marked by red crosses]
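
Since Mean Shift infers the number of clusters from the data, a quick follow-up (continuing from the code above) can inspect what it found:

Python3

import numpy as np

# Mean Shift discovers the number of clusters on its own
print("Clusters found:", len(mean_shift.cluster_centers_))

# Count how many points landed in each cluster
print("Points per cluster:", np.bincount(mean_shift.labels_))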

Example 2: Tuning Bandwidth Parameter

This code example illustrates the impact of the bandwidth parameter on clustering by setting a specific bandwidth value in the Mean Shift algorithm.

The visualization shows how adjusting the bandwidth influences the shape and size of the identified clusters.

Python3

from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
 
# Create synthetic data
X, _ = make_blobs(n_samples=300, centers=3, random_state=42)
 
# Apply Mean Shift clustering with a specific bandwidth
bandwidth = 1.5
mean_shift = MeanShift(bandwidth=bandwidth)
mean_shift.fit(X)
 
# Visualize the results
plt.scatter(X[:, 0], X[:, 1], c=mean_shift.labels_, cmap='viridis', marker='o')
plt.scatter(mean_shift.cluster_centers_[:, 0], mean_shift.cluster_centers_[:, 1],
            c='red', marker='x', s=200, linewidths=3, label='Cluster Centers')
plt.title('Mean Shift Clustering with Bandwidth = {}'.format(bandwidth))
plt.legend()
plt.show()


Output:

[Output figure: scatter plot titled “Mean Shift Clustering with Bandwidth = 1.5”, with cluster centers marked by red crosses]
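
To see the effect more directly, you can sweep several bandwidth values (continuing from the code above; the values chosen are illustrative) and watch the cluster count change:

Python3

# Smaller bandwidths carve the data into more clusters;
# larger bandwidths merge them into fewer
for bw in [0.5, 1.0, 2.0, 4.0]:
    ms = MeanShift(bandwidth=bw).fit(X)
    print(f"bandwidth={bw}: {len(ms.cluster_centers_)} clusters")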

Example 3: Comparing Mean Shift with K-Means

This code compares the Mean Shift and K-Means clustering algorithms on synthetic data.

  • The Mean Shift algorithm is a mode-seeking algorithm that finds the modes (peaks) of the data’s density distribution and groups the data points around those modes.
  • The K-Means algorithm is a centroid-based algorithm that groups the data points into a predefined number of clusters.

The code first creates synthetic data using the make_blobs() function from scikit-learn. Then, it applies both the Mean Shift and K-Means clustering algorithms to the data. Finally, it visualizes the results using the plt.scatter() function from Matplotlib. The visualization shows the data points colored by their cluster labels. The cluster centers are marked with red crosses.

Python3

from sklearn.cluster import MeanShift, KMeans
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
 
# Create synthetic data
X, _ = make_blobs(n_samples=300, centers=3, random_state=42)
 
# Apply Mean Shift clustering
mean_shift = MeanShift()
mean_shift.fit(X)
 
# Apply K-Means clustering for comparison
kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.fit(X)
 
# Visualize the results
plt.figure(figsize=(12, 5))
 
plt.subplot(1, 2, 1)
plt.scatter(X[:, 0], X[:, 1], c=mean_shift.labels_, cmap='viridis', marker='o')
plt.scatter(mean_shift.cluster_centers_[:, 0], mean_shift.cluster_centers_[:, 1],
            c='red', marker='x', s=200, linewidths=3, label='Cluster Centers')
plt.title('Mean Shift Clustering')
 
plt.subplot(1, 2, 2)
plt.scatter(X[:, 0], X[:, 1], c=kmeans.labels_, cmap='viridis', marker='o')
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
            c='red', marker='x', s=200, linewidths=3, label='Cluster Centers')
plt.title('K-Means Clustering')
 
plt.show()


Output:

[Output figure: side-by-side scatter plots comparing the Mean Shift and K-Means clusterings, with cluster centers marked by red crosses]
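
Beyond visual inspection, the two clusterings can be compared quantitatively, for example with the silhouette score (continuing from the code above):

Python3

from sklearn.metrics import silhouette_score

# Silhouette ranges from -1 to 1; higher means better-separated clusters
print("Mean Shift:", silhouette_score(X, mean_shift.labels_))
print("K-Means:  ", silhouette_score(X, kmeans.labels_))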

Conclusion

Mean shift clustering stands out as a powerful and versatile tool for clustering tasks. Its non-parametric nature, adaptability to different data types, and ability to handle noise make it a valuable addition to the machine learning toolkit. With its straightforward implementation and wide range of applications, mean shift clustering is a technique worth exploring for various data analysis and pattern recognition tasks.


