
K-Means Clustering using R

 


Hi All!

Today, we will be learning how to perform K-Means Clustering using R to achieve customer segmentation. This case-study comes under unsupervised machine learning (K-Means Clustering).

Problem Statement

As the owner of a mall, you want to understand who your target customers are, so that you can share your findings with your marketing team and ask them to devise new strategies accordingly.

Data-set used

We will be using the Mall_Customers.csv file, which can be found on Kaggle. The dataset has the following variables:

  • CustomerID: Unique ID assigned to the customer
  • Gender: Gender of the customer
  • Age: Age of the customer
  • Annual Income: Annual Income of the customer
  • Spending Score: Score assigned by the mall to the customer based on his/her spending history and sentiments.

Algorithm used

K-Means Clustering

Language used

R

Algorithm Description

  • Choose the number of clusters, k, that you want at the end. This is the only input the user has to provide; the algorithm does the rest.
  • k objects are selected randomly from the dataset to act as the initial centroids of the k clusters.
  • Cluster Assignment Step: Each remaining point in the dataset is assigned to a cluster by calculating the Euclidean distance between the point and every centroid. A point goes into the cluster whose centroid is nearest to it.
  • Centroid Update Step: Each time a point gets added to a cluster, the centroid is recalculated to take the added point into account. Note that the centroid is the mean of all the points in a specific cluster.
  • The Cluster Assignment Step and the Centroid Update Step are repeated until the cluster assignments stop changing, that is, until convergence is reached.
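To make the steps above concrete, here is a minimal sketch of the loop in plain R. It is only an illustration, not what kmeans() does internally: simple_kmeans and the toy data are made-up names for this example, and the sketch recomputes the centroids once after every full assignment pass (Lloyd's variant), whereas R's kmeans() uses the Hartigan-Wong algorithm by default.

```r
# Hypothetical helper: a bare-bones k-means loop (Lloyd's variant)
simple_kmeans <- function(x, k, max_iter = 100) {
  x <- as.matrix(x)
  # Pick k random rows as the initial centroids
  centroids <- x[sample(nrow(x), k), , drop = FALSE]
  cluster <- rep(0L, nrow(x))
  for (iter in seq_len(max_iter)) {
    # Cluster assignment step: squared Euclidean distance to every centroid
    d <- sapply(seq_len(k), function(j) rowSums(sweep(x, 2, centroids[j, ])^2))
    new_cluster <- max.col(-d, ties.method = "first")  # index of the nearest centroid
    # Stop when no assignment changes any more
    if (all(new_cluster == cluster)) break
    cluster <- new_cluster
    # Centroid update step: mean of the points in each cluster
    for (j in seq_len(k)) {
      members <- x[cluster == j, , drop = FALSE]
      if (nrow(members) > 0) centroids[j, ] <- colMeans(members)
    }
  }
  list(cluster = cluster, centers = centroids, iter = iter)
}

# Toy data (made up): two well-separated groups of 20 points each
set.seed(1)
toy <- rbind(matrix(rnorm(40, mean = 0, sd = 0.3), ncol = 2),
             matrix(rnorm(40, mean = 5, sd = 0.3), ncol = 2))
fit <- simple_kmeans(toy, k = 2)
table(fit$cluster)
```

For real work, stick to kmeans(); the sketch is only meant to show where the assignment and update steps sit inside the loop.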

The Code

Let us begin by reading the csv file.

data = read.csv("Mall_Customers.csv")
head(data)
##   CustomerID Gender Age Annual.Income..k.. Spending.Score..1.100.
## 1          1   Male  19                 15                     39
## 2          2   Male  21                 15                     81
## 3          3 Female  20                 16                      6
## 4          4 Female  23                 16                     77
## 5          5 Female  31                 17                     40
## 6          6 Female  22                 17                     76

Now, let’s rename the data columns and then look at the summary of the data.

colnames(data) = c("CustomerID", "Gender", "Age", "Annual_Income", "Spending_Score")
summary(data)
##    CustomerID        Gender         Age        Annual_Income    Spending_Score 
##  Min.   :  1.00   Female:112   Min.   :18.00   Min.   : 15.00   Min.   : 1.00  
##  1st Qu.: 50.75   Male  : 88   1st Qu.:28.75   1st Qu.: 41.50   1st Qu.:34.75  
##  Median :100.50                Median :36.00   Median : 61.50   Median :50.00  
##  Mean   :100.50                Mean   :38.85   Mean   : 60.56   Mean   :50.20  
##  3rd Qu.:150.25                3rd Qu.:49.00   3rd Qu.: 78.00   3rd Qu.:73.00  
##  Max.   :200.00                Max.   :70.00   Max.   :137.00   Max.   :99.00

We can see that the columns don’t contain any missing data. Let’s now do outlier analysis.
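summary() would have listed an NA's row for any column with missing values. If you prefer an explicit check, something along these lines should work. The toy data frame is made up for illustration; on the real dataset you would pass data instead.

```r
# Toy data frame (made-up values) with one missing Age
toy <- data.frame(Age = c(19, NA, 20),
                  Annual_Income = c(15, 15, 16))

# Number of missing values per column
na_counts <- colSums(is.na(toy))
print(na_counts)

# TRUE if there is at least one missing value anywhere
print(anyNA(toy))
```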

box = boxplot(data[, 3:5], col = "lightblue")

[Figure 1: boxplots of Age, Annual_Income and Spending_Score]

Points to be noted:

  • Age is right-skewed, i.e., 50 percent of the customers are young ([18, 36] years of age), while the older half is spread over a much wider range ((36, 70] years).
  • Annual_Income is right-skewed, and Spending_Score is roughly symmetric (its mean, 50.20, and median, 50.00, almost coincide).
  • We can see that the Annual_Income column contains outliers. Let’s handle them.

box$stats
##      [,1]  [,2] [,3]
## [1,] 18.0  15.0  1.0
## [2,] 28.5  41.0 34.5
## [3,] 36.0  61.5 50.0
## [4,] 49.0  78.0 73.0
## [5,] 70.0 126.0 99.0
## attr(,"class")
##       Age 
## "integer"

Note that column 2 contains the boxplot stats for Annual_Income. From top to bottom, the rows are the lower whisker, lower hinge, median, upper hinge and upper whisker. Let’s now calculate the quantiles.

quantile(data$Annual_Income, seq(0, 1, 0.02))
##     0%     2%     4%     6%     8%    10%    12%    14%    16%    18%    20% 
##  15.00  16.98  18.96  19.94  20.92  23.90  27.64  28.86  32.52  33.82  37.80 
##    22%    24%    26%    28%    30%    32%    34%    36%    38%    40%    42% 
##  39.00  40.00  42.74  43.72  46.00  47.68  48.00  49.64  54.00  54.00  54.00 
##    44%    46%    48%    50%    52%    54%    56%    58%    60%    62%    64% 
##  57.56  59.54  60.00  61.50  62.00  63.00  63.44  65.00  67.00  69.38  71.00 
##    66%    68%    70%    72%    74%    76%    78%    80%    82%    84%    86% 
##  71.34  73.00  74.30  76.28  77.26  78.00  78.00  78.20  81.72  86.16  87.00 
##    88%    90%    92%    94%    96%    98%   100% 
##  88.00  93.40  98.08 101.12 103.40 120.12 137.00

Thus, we will replace values above 126 with 126, which is the upper whisker (the largest value that is not an outlier) of the boxplot for Annual_Income.

data$Annual_Income = ifelse(data$Annual_Income>=126, 126, data$Annual_Income)
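The whisker value does not have to be typed in by hand. As a sketch, boxplot.stats() returns the same five numbers without drawing anything, and pmin() does the capping. The income vector below is made up for illustration.

```r
# Made-up income values with one extreme observation
income <- c(15, 16, 17, 40, 54, 60, 62, 78, 87, 300)

stats <- boxplot.stats(income)
upper <- stats$stats[5]   # upper whisker: largest value that is not an outlier
print(stats$out)          # the values flagged as outliers

# Cap everything above the upper whisker at the whisker value
capped <- pmin(income, upper)
print(capped)
```

This applies the same capping as the ifelse() call above; it just reads the threshold from the data instead of hard-coding it.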

Let’s once again plot the boxplots.

box = boxplot(data[, 3:5], col = "lightblue")

[Figure 2: boxplots of Age, Annual_Income and Spending_Score after capping Annual_Income]

Let’s add dummy variables for Gender.

data$IsMale = ifelse(data$Gender=="Male", 1, 0)
data$IsFemale = ifelse(data$Gender=="Female", 1, 0)
# Excluding Gender and Customer ID
data = data[, 3:ncol(data)]
head(data)
##   Age Annual_Income Spending_Score IsMale IsFemale
## 1  19            15             39      1        0
## 2  21            15             81      1        0
## 3  20            16              6      0        1
## 4  23            16             77      0        1
## 5  31            17             40      0        1
## 6  22            17             76      0        1

Now we are satisfied that there are no outliers left in our data, so let’s start with clustering. Let’s first scale the dataset. We will be using min-max scaling in our case.

maxs <- apply(data, 2, max) 
mins <- apply(data, 2, min)
data_sc = scale(data, center = mins, scale = maxs - mins)
head(data_sc)
##             Age Annual_Income Spending_Score IsMale IsFemale
## [1,] 0.01923077   0.000000000     0.38775510      1        0
## [2,] 0.05769231   0.000000000     0.81632653      1        0
## [3,] 0.03846154   0.009009009     0.05102041      0        1
## [4,] 0.09615385   0.009009009     0.77551020      0        1
## [5,] 0.25000000   0.018018018     0.39795918      0        1
## [6,] 0.07692308   0.018018018     0.76530612      0        1
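Min-max scaling maps every column to the range [0, 1] using (x - min) / (max - min). Here is a small sketch on a made-up age vector; min_max is a hypothetical helper written just for this example. With a minimum of 18 and a maximum of 70, an age of 19 becomes 1/52, which is the 0.01923077 we see in the first row above.

```r
# Hypothetical helper: min-max scaling of a numeric vector
min_max <- function(x) (x - min(x)) / (max(x) - min(x))

age <- c(18, 19, 44, 70)   # made-up values spanning the same range as Age
scaled <- min_max(age)
print(scaled)

# The same result via scale(), as used in the tutorial
scaled_via_scale <- as.vector(scale(age, center = min(age), scale = max(age) - min(age)))
print(all.equal(scaled, scaled_via_scale))
```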

To create eye-catching visuals, we need the factoextra library in R. You can run the following commands to install and load it in your R workspace.

# install.packages("factoextra")
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 3.5.3
library(factoextra)
## Warning: package 'factoextra' was built under R version 3.5.3
## Welcome! Want to learn more? See two factoextra-related books at https://goo.gl/ve3WBa

Let’s now create a wss (within sum of squares) plot to find the optimal number of clusters for this dataset. In this plot, the number of clusters is on the x-axis and the wss values are on the y-axis. The point beyond which the wss values change very little is chosen as the optimal number of clusters.

fviz_nbclust(data_sc, kmeans, method ="wss")

[Figure 3: wss plot, with the number of clusters k on the x-axis and the wss values on the y-axis]
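If you want to see where the numbers in such a plot come from, the curve can be sketched by hand with base R. I am assuming here that the plotted value is the total within sum of squares (tot.withinss) for each k; the toy data with three groups is made up for illustration.

```r
# Toy data (made up): three well-separated groups of 30 points each
set.seed(123)
toy <- rbind(matrix(rnorm(60, mean = 0, sd = 0.3), ncol = 2),
             matrix(rnorm(60, mean = 3, sd = 0.3), ncol = 2),
             matrix(rnorm(60, mean = 6, sd = 0.3), ncol = 2))

# Total within sum of squares for k = 1, ..., 6
ks <- 1:6
wss <- sapply(ks, function(k) kmeans(toy, centers = k, nstart = 25)$tot.withinss)
print(round(wss, 2))

# Base R version of the elbow plot
plot(ks, wss, type = "b",
     xlab = "Number of clusters k",
     ylab = "Total within sum of squares")
```

On data like this, the curve should drop sharply up to the true number of groups and flatten afterwards, which is the "elbow" we look for.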

We can see that after k = 2, the drop in wss becomes very small. Thus, let’s take k = 2.

set.seed(123)
k_means_results <- kmeans(data_sc, 2, nstart = 25)

Above, we have chosen nstart = 25. By default, nstart = 1, but larger values (25 to 50) are advised for more stable clustering. nstart is important because it determines how many random sets of initial centroids are tried at the start of the algorithm. kmeans() runs once for each set and keeps only the run with the lowest total wss.
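A quick way to get a feel for nstart is to compare a single random start with 25 of them. This is only a sketch on made-up data with four groups; the seeds are arbitrary.

```r
# Toy data (made up): four groups of 50 points each
set.seed(42)
toy <- rbind(matrix(rnorm(100, mean = 0, sd = 0.5), ncol = 2),
             matrix(rnorm(100, mean = 3, sd = 0.5), ncol = 2),
             matrix(rnorm(100, mean = 6, sd = 0.5), ncol = 2),
             matrix(rnorm(100, mean = 9, sd = 0.5), ncol = 2))

# One random start versus the best of 25 random starts
set.seed(1)
fit_one <- kmeans(toy, centers = 4, nstart = 1)
set.seed(1)
fit_many <- kmeans(toy, centers = 4, nstart = 25)

print(c(one_start = fit_one$tot.withinss, best_of_25 = fit_many$tot.withinss))
```

The best of 25 starts should never be noticeably worse than a single start, and with an unlucky single start it can be a lot better.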

# Printing the results
print(k_means_results)
## K-means clustering with 2 clusters of sizes 88, 112
## 
## Cluster means:
##         Age Annual_Income Spending_Score IsMale IsFemale
## 1 0.4193619     0.4232187      0.4848098      1        0
## 2 0.3865041     0.3986486      0.5155794      0        1
## 
## Clustering vector:
##   [1] 1 1 2 2 2 2 2 2 1 2 1 2 2 2 1 1 2 1 1 2 1 1 2 1 2 1 2 1 2 2 1 2 1 1 2 2 2
##  [38] 2 2 2 2 1 1 2 2 2 2 2 2 2 2 1 2 1 2 1 2 1 2 1 1 1 2 2 1 1 2 2 1 2 1 2 2 2
##  [75] 1 1 2 1 2 2 1 1 1 2 2 1 2 2 2 2 2 1 1 2 2 1 2 2 1 1 2 2 1 1 1 2 2 1 1 1 1
## [112] 2 2 1 2 2 2 2 2 2 1 2 2 1 2 2 1 1 1 1 1 1 2 2 1 2 2 1 1 2 2 1 2 2 1 1 1 2
## [149] 2 1 1 1 2 2 2 2 1 2 1 2 2 2 1 2 1 2 1 2 2 1 1 1 1 1 2 2 1 1 1 1 2 2 1 2 2
## [186] 1 2 1 2 2 2 2 1 2 2 2 2 1 1 1
## 
## Within cluster sum of squares by cluster:
## [1] 19.55725 19.37968
##  (between_SS / total_SS =  71.7 %)
## 
## Available components:
## 
## [1] "cluster"      "centers"      "totss"        "withinss"     "tot.withinss"
## [6] "betweenss"    "size"         "iter"         "ifault"

Let’s see the mean of each variable, in its original units, for each cluster.

aggregate(data,by=list(cluster=k_means_results$cluster), mean)
##   cluster      Age Annual_Income Spending_Score IsMale IsFemale
## 1       1 39.80682      61.97727       48.51136      1        0
## 2       2 38.09821      59.25000       51.52679      0        1

Adding cluster column to the dataset

data$Clusters = k_means_results$cluster

Accessing results

k_means_results$cluster
##   [1] 1 1 2 2 2 2 2 2 1 2 1 2 2 2 1 1 2 1 1 2 1 1 2 1 2 1 2 1 2 2 1 2 1 1 2 2 2
##  [38] 2 2 2 2 1 1 2 2 2 2 2 2 2 2 1 2 1 2 1 2 1 2 1 1 1 2 2 1 1 2 2 1 2 1 2 2 2
##  [75] 1 1 2 1 2 2 1 1 1 2 2 1 2 2 2 2 2 1 1 2 2 1 2 2 1 1 2 2 1 1 1 2 2 1 1 1 1
## [112] 2 2 1 2 2 2 2 2 2 1 2 2 1 2 2 1 1 1 1 1 1 2 2 1 2 2 1 1 2 2 1 2 2 1 1 1 2
## [149] 2 1 1 1 2 2 2 2 1 2 1 2 2 2 1 2 1 2 1 2 2 1 1 1 1 1 2 2 1 1 1 1 2 2 1 2 2
## [186] 1 2 1 2 2 2 2 1 2 2 2 2 1 1 1

Checking the sizes of the clusters

k_means_results$size
## [1]  88 112

Accessing the centers of the clusters.

k_means_results$centers
##         Age Annual_Income Spending_Score IsMale IsFemale
## 1 0.4193619     0.4232187      0.4848098      1        0
## 2 0.3865041     0.3986486      0.5155794      0        1
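These centers are on the min-max scale, so they are hard to read directly. As a sketch, they can be mapped back to the original units by reversing the scaling, x = x_scaled * (max - min) + min. The toy data frame below is made up for illustration and only mimics two of our columns.

```r
# Toy data (made-up values) standing in for two of the tutorial's columns
toy <- data.frame(Age = c(18, 20, 22, 60, 65, 70),
                  Annual_Income = c(15, 18, 20, 100, 110, 126))

# Same min-max scaling as in the tutorial
mins <- apply(toy, 2, min)
maxs <- apply(toy, 2, max)
toy_sc <- scale(toy, center = mins, scale = maxs - mins)

set.seed(123)
fit <- kmeans(toy_sc, centers = 2, nstart = 25)

# Undo the scaling column by column: multiply by the range, then add the minimum
centers_orig <- sweep(sweep(fit$centers, 2, maxs - mins, "*"), 2, mins, "+")
print(centers_orig)
```

The back-transformed centers should agree with the per-cluster means that aggregate() gives on the unscaled data.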

Let’s now plot our clusters.

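# Pass data_sc, the matrix that was given to kmeans(); data now also holds the cluster column
fviz_cluster(k_means_results, data = data_sc, palette = c("#00AFBB", "#FC4E07"),
             ellipse.type = "euclid",
             star.plot = FALSE,
             repel = TRUE,
             ggtheme = theme_minimal()
             )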

[Figure 4: cluster plot of the two clusters]

Weaknesses of K-means clustering

  • It requires input from the user up front, i.e., the number of clusters has to be passed to the function before the algorithm can proceed. Solution? Use the WSS plot.
  • It is highly sensitive to outliers. Solution? If you have lots of outliers, use the PAM algorithm.
  • The final result depends on the random initial centroids, and therefore on the value provided to nstart. Solution? Assigning a higher value to nstart is highly recommended.
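To give an idea of the PAM suggestion above, here is a small sketch using pam() from the cluster package, which normally ships with R as a recommended package. The toy data, with two groups and one extreme point, is made up for illustration.

```r
library(cluster)  # provides pam()

# Toy data (made up): two groups of 20 points plus one extreme outlier
set.seed(7)
toy <- rbind(matrix(rnorm(40, mean = 0, sd = 0.3), ncol = 2),
             matrix(rnorm(40, mean = 4, sd = 0.3), ncol = 2),
             c(30, 30))

# PAM: cluster centers (medoids) are actual data points
pam_fit <- pam(toy, k = 2)
print(pam_fit$medoids)
print(table(pam_fit$clustering))

# K-means on the same data, for comparison
set.seed(7)
km_fit <- kmeans(toy, centers = 2, nstart = 25)
print(km_fit$size)
```

On data like this, k-means tends to spend a whole cluster on the single outlier, while PAM keeps the two real groups apart, because it works with medoids and plain distances rather than means and squared distances.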

So guys, with this, I conclude this tutorial. Stay tuned for more such interesting case-studies. Also, don’t forget to check out our YouTube channel, ML for Analytics.

 

 

 
