Deep Learning PyTorch Course, K-Means Clustering

Advances in deep learning and machine learning have significantly changed how we analyze data, and clustering is one of the core techniques in this toolkit. This post explains how to implement the K-means clustering algorithm in PyTorch and use it for data analysis.

1. What is K-means Clustering?

K-means clustering is an unsupervised learning algorithm that partitions a set of data points into K clusters. It chooses the clusters so as to minimize the sum of squared distances between each data point and the centroid of the cluster it belongs to. As a result, points within a cluster are close to one another. Because the total sum of squares of the data is fixed, minimizing this within-cluster quantity is equivalent to maximizing the between-cluster sum of squares, so the clusters also tend to be well separated.
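Written out with standard notation (C_j is the set of points assigned to cluster j and μ_j is its centroid; these symbols are our own labels), the objective K-means minimizes, often called the inertia, is:

```latex
J = \sum_{j=1}^{K} \sum_{\mathbf{x}_i \in C_j} \lVert \mathbf{x}_i - \boldsymbol{\mu}_j \rVert^2,
\qquad
\boldsymbol{\mu}_j = \frac{1}{|C_j|} \sum_{\mathbf{x}_i \in C_j} \mathbf{x}_i
```

The second formula is why the update step uses the mean: for a fixed assignment, the mean of a cluster's points is the point that minimizes that cluster's sum of squared distances.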

2. How K-means Clustering Works

  1. Initialization: Choose K initial centroids, commonly by selecting K data points at random.
  2. Assignment Step: Assign each data point to its nearest centroid (usually by Euclidean distance).
  3. Update Step: Move each centroid to the mean of the data points assigned to it.
  4. Convergence Check: Stop when the centroids no longer change (or change by less than a small tolerance) or when a maximum number of iterations is reached; otherwise, return to step 2.

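Steps 2 and 3 are repeated until convergence. Neither step can increase the total within-cluster squared distance, so the algorithm is guaranteed to converge, but only to a local optimum: the final clusters can depend on the initial centroids.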
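To make one pass through the four steps concrete, here is a tiny, hand-checkable sketch on four illustrative points. The data and the choice of initial centroids are assumptions made only for this example.

```python
import torch

# Four illustrative 2D points forming two obvious groups
X = torch.tensor([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])

# Step 1 (Initialization): for this example, pick points 0 and 2 as centroids
centroids = X[[0, 2]].clone()

# Step 2 (Assignment): distance from every point to every centroid, shape (4, 2)
distances = torch.cdist(X, centroids)
labels = torch.argmin(distances, dim=1)

# Step 3 (Update): each centroid becomes the mean of the points assigned to it
new_centroids = torch.stack([X[labels == j].mean(dim=0) for j in range(2)])

# Step 4 (Convergence check): how far did the centroids move in total?
shift = torch.norm(new_centroids - centroids)

print(labels.tolist())         # [0, 0, 1, 1]
print(new_centroids.tolist())  # [[0.0, 0.5], [10.0, 10.5]]
print(round(shift.item(), 3))  # 0.707
```

The shift is nonzero, so a real run would perform another iteration; on this data the second iteration leaves the centroids unchanged and the algorithm stops.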

3. Advantages and Disadvantages of K-means Clustering

Advantages

  • Simple to implement and understand.
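  • Computationally efficient: each iteration takes time proportional to the number of points times the number of clusters times the data dimension, and in practice it usually converges within a small number of iterations.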

Disadvantages

  • The K value (number of clusters) must be specified in advance.
  • Does not perform well with non-spherical clusters.
  • Can be sensitive to outliers.
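The outlier sensitivity comes from the update step: a centroid is a mean, and a single distant point can pull a mean far away from the bulk of a cluster. A minimal illustration, using made-up points:

```python
import torch

# A tight cluster of four illustrative points centered on (0.5, 0.5)
cluster = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(cluster.mean(dim=0).tolist())  # [0.5, 0.5]

# The same cluster plus one distant outlier at (20, 20)
with_outlier = torch.cat([cluster, torch.tensor([[20.0, 20.0]])])
# The mean moves to roughly (4.4, 4.4), outside the original cluster
print(with_outlier.mean(dim=0).tolist())
```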

4. Implementing K-means Clustering in PyTorch

Now, let’s implement K-means clustering in PyTorch. In this example, we will generate 2D data for clustering.

4.1. Installing Required Libraries

First, install the necessary libraries (for example, with pip install torch numpy matplotlib) and import them.

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
```

4.2. Generating Data

We generate 300 random 2D points drawn from three Gaussian clusters of 100 points each, centered at (0, 0), (5, 5), and (1, 8).

```python
# Generate data
np.random.seed(42)
num_samples_per_cluster = 100
C1 = np.random.randn(num_samples_per_cluster, 2) + np.array([0, 0])
C2 = np.random.randn(num_samples_per_cluster, 2) + np.array([5, 5])
C3 = np.random.randn(num_samples_per_cluster, 2) + np.array([1, 8])

data = np.vstack((C1, C2, C3))
plt.scatter(data[:, 0], data[:, 1])
plt.title("Generated Data")
plt.show()
```

4.3. Implementing the K-means Algorithm

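Now we implement the algorithm. The function follows the four steps from Section 2: torch.cdist computes the distance from every point to every centroid, torch.argmin assigns each point to its nearest centroid, and each centroid is then moved to the mean of its points until the centroids stop moving.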

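```python
# K-means implementation
def k_means(X, k, num_iters=100, tol=1e-4):
    # Convert the NumPy data to a float32 tensor once, outside the loop
    X_t = torch.as_tensor(X, dtype=torch.float32)

    # Initialization: use k distinct, randomly chosen data points as centroids
    idx = torch.as_tensor(np.random.choice(X_t.shape[0], k, replace=False), dtype=torch.long)
    centroids = X_t[idx]

    for _ in range(num_iters):
        # Assignment step: label each point with the index of its nearest centroid
        distances = torch.cdist(X_t, centroids)
        labels = torch.argmin(distances, dim=1)

        # Update step: move each centroid to the mean of its assigned points;
        # a cluster that ends up empty keeps its previous centroid
        new_centroids = centroids.clone()
        for i in range(k):
            mask = labels == i
            if mask.any():
                new_centroids[i] = X_t[mask].mean(dim=0)

        # Convergence check: stop once the centroids have (almost) stopped moving
        shift = torch.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break

    # Final assignment so the labels match the returned centroids
    labels = torch.argmin(torch.cdist(X_t, centroids), dim=1)
    return labels.numpy(), centroids.numpy()
```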
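The Python loop over clusters in the update step can also be replaced with tensor operations, which keeps the whole computation inside PyTorch and lets it run on a GPU. The sketch below is a hypothetical variant (k_means_torch and its seed and device parameters are our own names, not a PyTorch API); it assumes Tensor.index_add_ to sum the points of each cluster and torch.bincount to count them.

```python
import torch

def k_means_torch(X, k, num_iters=100, tol=1e-4, seed=0, device="cpu"):
    """Hypothetical loop-free K-means sketch; returns (labels, centroids) as tensors."""
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    n, d = X.shape

    # Initialization: k distinct random data points, drawn with a seeded CPU generator
    gen = torch.Generator().manual_seed(seed)
    idx = torch.randperm(n, generator=gen)[:k].to(device)
    centroids = X[idx]

    for _ in range(num_iters):
        # Assignment step: nearest centroid for every point
        labels = torch.cdist(X, centroids).argmin(dim=1)

        # Update step without a Python loop over clusters:
        # index_add_ sums the points of each cluster, bincount counts them
        sums = torch.zeros(k, d, device=device).index_add_(0, labels, X)
        counts = torch.bincount(labels, minlength=k).to(X.dtype)
        means = sums / counts.clamp(min=1).unsqueeze(1)
        # Empty clusters (count 0) keep their previous centroid
        new_centroids = torch.where(counts.unsqueeze(1) > 0, means, centroids)

        # Convergence check: total movement of the centroids
        shift = torch.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break

    # Final assignment against the returned centroids
    labels = torch.cdist(X, centroids).argmin(dim=1)
    return labels, centroids
```

With the data from Section 4.2 it could be called as k_means_torch(data, 3), or with device="cuda" on a machine with a CUDA GPU. Because it returns tensors, convert them with .cpu().numpy() before passing them to matplotlib.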

4.4. Running the Algorithm

We will perform K-means clustering and visualize the results.

```python
# Run K-means
k = 3
labels, centroids = k_means(data, k)

# Visualize the results
plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis')
plt.scatter(centroids[:, 0], centroids[:, 1], s=200, c='red', marker='X')  # Mark centroids
plt.title("K-Means Clustering")
plt.show()
```
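The plot gives a visual check. A numeric check is the within-cluster sum of squared distances (inertia), the quantity K-means tries to reduce; comparing it across runs with different random initializations or different values of k can help judge a result. The helper below is a hypothetical sketch, not part of any library:

```python
import numpy as np

def inertia(X, labels, centroids):
    """Hypothetical helper: within-cluster sum of squared distances."""
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    labels = np.asarray(labels)
    # Vector from each point to the centroid of the cluster it was assigned to
    diffs = X - centroids[labels]
    return float((diffs ** 2).sum())
```

After running the code above, inertia(data, labels, centroids) returns a single number; lower values mean tighter clusters for the same k.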

5. Applications of K-means Clustering

K-means clustering is used in fields such as customer segmentation, image compression (color quantization), and recommendation systems. It is also a useful tool for data analysts who want to explore the structure of a dataset and discover patterns.

6. Conclusion

K-means clustering is an easy-to-understand algorithm that works well when the clusters are compact and roughly spherical. Implementing it in PyTorch also gave us hands-on practice with tensor operations such as torch.cdist and torch.argmin, which appear throughout deep learning code. I hope this course helps you understand the concept of data clustering and become more familiar with PyTorch.

I hope the code and examples have shown you some of the fun and possibilities of working with PyTorch. We will cover more topics in data analysis and deep learning in future posts, so please stay tuned. Thank you!