The data given by `x` are clustered by the k-means method, which aims to partition the points into k groups such that the sum of squares from points to the assigned cluster centres is minimized. At the minimum, all cluster centres are at the mean of their Voronoi sets (the set of data points which are nearest to the cluster centre).
The algorithm of Hartigan and Wong (1979) is used by default. Note that some authors use k-means to refer to a specific algorithm rather than the general method: most commonly the algorithm given by MacQueen (1967) but sometimes that given by Lloyd (1957) and Forgy (1965). The Hartigan–Wong algorithm generally does a better job than either of those, but trying several random starts (`nstart > 1`) is often recommended. In rare cases, when some of the points (rows of `x`) are extremely close, the algorithm may not converge in the “Quick-Transfer” stage, signalling a warning (and returning `ifault = 4`). Slight rounding of the data may be advisable in that case.
For ease of programmatic exploration, `k = 1` is allowed; the result then still contains the single centre (`centers`) and its `withinss`.
Except for the Lloyd–Forgy method, k clusters will always be returned if a number is specified. If an initial matrix of centres is supplied, it is possible that no point will be closest to one or more centres, which is currently an error for the Hartigan–Wong method.
Example adapted from the tidymodels k-means article (“Exploratory clustering” section): https://www.tidymodels.org/learn/statistics/k-means/index.html#exploratory-clustering
library(tidyverse)
library(tidymodels)
# Simulate three clusters of normally distributed points (rnorm's default
# sd = 1) around known centres, keeping each point's true cluster label.
set.seed(27)
centers <- tibble(
  cluster = factor(1:3),
  num_points = c(100, 150, 50), # how many points to draw for each cluster
  x1 = c(5, 0, -3),             # x1 coordinate of each true centre
  x2 = c(-1, 1, -2)             # x2 coordinate of each true centre
)

# One rnorm(n, mean) draw per cluster, x1 before x2 -- keep that order so the
# seeded draws reproduce.
labelled_points <- centers |>
  mutate(
    x1 = map2(num_points, x1, \(n, mu) rnorm(n, mu)),
    x2 = map2(num_points, x2, \(n, mu) rnorm(n, mu))
  ) |>
  select(!num_points) |>
  unnest(cols = c(x1, x2))

ggplot(labelled_points, aes(x1, x2, color = cluster)) +
  geom_point()
# k-means sees only the coordinates; the true labels are dropped.
points <- select(labelled_points, !cluster)
kclust <- kmeans(points, centers = 3)
kclust
## K-means clustering with 3 clusters of sizes 148, 51, 101
##
## Cluster means:
## x1 x2
## 1 0.08853475 1.045461
## 2 -3.14292460 -2.000043
## 3 5.00401249 -1.045811
##
## Clustering vector:
## [1] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## [38] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## [75] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 1 1 1 1
## [112] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [149] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [186] 1 1 1 1 1 1 1 1 1 1 1 1 1 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [223] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2
## [260] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [297] 2 2 2 2
##
## Within cluster sum of squares by cluster:
## [1] 298.9415 108.8112 243.2092
## (between_SS / total_SS = 82.5 %)
##
## Available components:
##
## [1] "cluster" "centers" "totss" "withinss" "tot.withinss"
## [6] "betweenss" "size" "iter" "ifault"
# Per-point view: the input rows plus the assigned `.cluster`.
augment(kclust, points)
## # A tibble: 300 × 3
## x1 x2 .cluster
## <dbl> <dbl> <fct>
## 1 6.91 -2.74 3
## 2 6.14 -2.45 3
## 3 4.24 -0.946 3
## 4 3.54 0.287 3
## 5 3.91 0.408 3
## 6 5.30 -1.58 3
## 7 5.01 -1.77 3
## 8 6.16 -1.68 3
## 9 7.13 -2.17 3
## 10 5.24 -2.42 3
## # ℹ 290 more rows
# Per-cluster view: centre coordinates, size and within-cluster sum of squares.
tidy(kclust)
## # A tibble: 3 × 5
## x1 x2 size withinss cluster
## <dbl> <dbl> <int> <dbl> <fct>
## 1 0.0885 1.05 148 299. 1
## 2 -3.14 -2.00 51 109. 2
## 3 5.00 -1.05 101 243. 3
# Same idea on real data: cluster penguin bill length/depth into 3 groups and
# colour each bird by its assigned cluster.
# NOTE(review): these column names match a `penguins` with bill_length_mm /
# bill_depth_mm (e.g. modeldata's), not R >= 4.5's datasets::penguins --
# confirm which one is attached.
data_prep <- remove_missing(select(penguins, bill_length_mm, bill_depth_mm))
kmeans(data_prep, centers = 3) |>
  augment(data_prep) |>
  ggplot(aes(bill_length_mm, bill_depth_mm, color = .cluster)) +
  geom_point()
#' Set the default number of k-means clusters for the geom_kmeans*() layers
#'
#' Stores `num_centers` in the global option `ggkmeans.num_centers`, which the
#' compute_panel_kmeans*() functions use as the default for their `centers`
#' argument.
#'
#' @param num_centers Value later passed to kmeans() as `centers`; default 3.
#' @return Invisibly, the previous option value(s) as returned by options().
set_num_centers <- function(num_centers = 3) {
  options(list(ggkmeans.num_centers = num_centers))
}
# Stat compute_panel(): cluster one panel's x/y positions with kmeans() and
# return those points with an added `.cluster` column (via augment()).
#
# data    - the layer data for one panel; only its `x` and `y` columns are used.
# scales  - panel scales supplied by ggplot2 (unused).
# centers - passed to kmeans(); defaults to the `ggkmeans.num_centers` option
#           (see set_num_centers()), falling back to 3.
# seed    - seed for kmeans()'s random starts, so a redraw clusters identically.
#
# The caller's RNG state is saved and restored on exit, so drawing a plot does
# not silently reset the session's random-number stream.
compute_panel_kmeans <- function(data, scales,
                                 centers = getOption("ggkmeans.num_centers", 3),
                                 seed = 1234) {
  seed_env <- globalenv()
  old_seed <- get0(".Random.seed", envir = seed_env, inherits = FALSE)
  on.exit({
    if (!is.null(old_seed)) {
      assign(".Random.seed", old_seed, envir = seed_env)
    } else if (exists(".Random.seed", envir = seed_env, inherits = FALSE)) {
      # No RNG state existed before; remove the one set.seed() created.
      rm(".Random.seed", envir = seed_env)
    }
  }, add = TRUE)
  set.seed(seed)

  # Rows with a missing coordinate cannot be clustered.
  data_prep <- data |>
    select(x, y) |>
    remove_missing()

  data_prep |>
    kmeans(centers = centers) |>
    augment(data_prep)
}
# Stat that clusters each panel's points; by default it maps colour to the
# `.cluster` column that compute_panel_kmeans() adds.
StatKmeans <- ggproto(
  "StatKmeans", Stat,
  compute_panel = compute_panel_kmeans,
  default_aes = aes(color = after_stat(.cluster))
)

#' Points coloured by their k-means cluster
#'
#' Draws the layer's x/y positions as points (GeomPoint) after StatKmeans has
#' clustered them. The cluster count defaults to the `ggkmeans.num_centers`
#' option (see set_num_centers()), else 3.
#' @export
geom_kmeans <- make_constructor(GeomPoint, stat = StatKmeans)
# Drop incomplete rows once, up front; remove_missing() reports how many.
penguins_clean <- remove_missing(penguins)
## Warning: Removed 11 rows containing missing values or values outside the scale
## range.
ggplot(penguins_clean, aes(bill_length_mm, bill_depth_mm)) +
  geom_point()
# Add a k-means layer on top of the plain scatter plot.
last_plot() +
  geom_kmeans()
# Stat compute_panel(): cluster one panel's x/y positions with kmeans() and
# return one row per cluster (via tidy()). The centre is in `x`/`y`, plus
# `.size`, `withinss` and `cluster`.
#
# data    - the layer data for one panel; only its `x` and `y` columns are used.
# scales  - panel scales supplied by ggplot2 (unused).
# centers - passed to kmeans(); defaults to the `ggkmeans.num_centers` option
#           (see set_num_centers()), falling back to 3.
# seed    - seed for kmeans()'s random starts, so a redraw clusters identically.
#
# The caller's RNG state is saved and restored on exit, so drawing a plot does
# not silently reset the session's random-number stream.
compute_panel_kmeans_tidy <- function(data, scales,
                                      centers = getOption("ggkmeans.num_centers", 3),
                                      seed = 1234) {
  seed_env <- globalenv()
  old_seed <- get0(".Random.seed", envir = seed_env, inherits = FALSE)
  on.exit({
    if (!is.null(old_seed)) {
      assign(".Random.seed", old_seed, envir = seed_env)
    } else if (exists(".Random.seed", envir = seed_env, inherits = FALSE)) {
      # No RNG state existed before; remove the one set.seed() created.
      rm(".Random.seed", envir = seed_env)
    }
  }, add = TRUE)
  set.seed(seed)

  # Rows with a missing coordinate cannot be clustered.
  data_prep <- data |>
    select(x, y) |>
    remove_missing()

  # The cluster size is renamed so a column called `size` is not taken as the
  # point-size aesthetic.
  data_prep |>
    kmeans(centers = centers) |>
    tidy() |>
    rename(.size = size)
}
# Stat that reduces each panel to its k-means cluster centres
# (one row per cluster, from compute_panel_kmeans_tidy()).
StatKmeansCenters <- ggproto(
  "StatKmeansCenters", Stat,
  compute_panel = compute_panel_kmeans_tidy
)
# GeomPoint variant whose defaults draw a large circled-X marker (pch 13,
# size 7); used below to mark cluster centres. All other defaults are
# inherited from GeomPoint.
GeomPointX <- ggproto("GeomPointX", GeomPoint,
  default_aes = modifyList(
    GeomPoint$default_aes,
    aes(shape = 13, size = 7),
    keep.null = TRUE
  )
)
#' k-means cluster centres, drawn as large circled-X points
#'
#' Combines StatKmeansCenters (one row per cluster centre) with GeomPointX.
#' @export
geom_kmeans_center <- make_constructor(GeomPointX, stat = StatKmeansCenters)

# Overlay the centres on the previous plot.
last_plot() + geom_kmeans_center()
# Stat compute_panel(): cluster one panel's x/y positions with kmeans() and
# return one row per point. Each row holds the point (`x`, `y`), its `.cluster`,
# and that cluster's centre as `xend`/`yend`, which is the shape GeomSegment
# needs to draw point-to-centre segments. The cluster's `.size` and `withinss`
# ride along.
#
# data    - the layer data for one panel; only its `x` and `y` columns are used.
# scales  - panel scales supplied by ggplot2 (unused).
# centers - passed to kmeans(); defaults to the `ggkmeans.num_centers` option
#           (see set_num_centers()), falling back to 3.
# seed    - seed for kmeans()'s random starts, so a redraw clusters identically.
#
# The caller's RNG state is saved and restored on exit, so drawing a plot does
# not silently reset the session's random-number stream.
compute_panel_kmeans_lengths <- function(data, scales,
                                         centers = getOption("ggkmeans.num_centers", 3),
                                         seed = 1234) {
  seed_env <- globalenv()
  old_seed <- get0(".Random.seed", envir = seed_env, inherits = FALSE)
  on.exit({
    if (!is.null(old_seed)) {
      assign(".Random.seed", old_seed, envir = seed_env)
    } else if (exists(".Random.seed", envir = seed_env, inherits = FALSE)) {
      # No RNG state existed before; remove the one set.seed() created.
      rm(".Random.seed", envir = seed_env)
    }
  }, add = TRUE)
  set.seed(seed)

  # Rows with a missing coordinate cannot be clustered.
  data_prep <- data |>
    select(x, y) |>
    remove_missing()

  # Fit once so the centres and the point labels come from the same model.
  kmeaned <- kmeans(data_prep, centers = centers)
  points <- augment(kmeaned, data_prep)

  # Centre coordinates become segment ends (so they don't collide with the
  # points' x/y). `.cluster` is the only key; name it explicitly rather than
  # relying on dplyr's guess, and assert one centre per cluster.
  kmeaned |>
    tidy() |>
    rename(.size = size, .cluster = cluster, xend = x, yend = y) |>
    right_join(points, by = join_by(.cluster), relationship = "one-to-many")
}
# Stat producing point-to-centre segments (see compute_panel_kmeans_lengths());
# segments are coloured by the `.cluster` column it returns.
StatKmeansSegments <- ggproto(
  "StatKmeansSegments", Stat,
  compute_panel = compute_panel_kmeans_lengths,
  default_aes = aes(color = after_stat(.cluster))
)
# GeomSegment variant defaulting to thin (linewidth 0.2), half-transparent
# (alpha 0.5) lines, so many point-to-centre segments don't swamp the plot.
GeomSegmentThin <- ggproto(
  "GeomSegmentThin", GeomSegment,
  default_aes = modifyList(
    GeomSegment$default_aes,
    aes(linewidth = .2, alpha = .5)
  )
)
#' Segments joining each point to its k-means cluster centre
#'
#' Combines StatKmeansSegments with GeomSegmentThin.
#' @export
geom_kmeans_lengths <- make_constructor(GeomSegmentThin, stat = StatKmeansSegments)

# Overlay the point-to-centre segments on the previous plot.
last_plot() + geom_kmeans_lengths()
## Joining with `by = join_by(.cluster)`
# Reference: the same bill measurements coloured by the true species.
ggplot(penguins_clean, aes(bill_length_mm, bill_depth_mm, color = species)) +
  geom_point()

# k-means points plus centres with 2 clusters...
set_num_centers(2)
ggplot(penguins_clean, aes(bill_length_mm, bill_depth_mm)) +
  geom_kmeans() +
  geom_kmeans_center()

# ...then redraw that plot after setting the option to 3, 4 and 5 centres.
invisible(lapply(3:5, \(num_centers) {
  set_num_centers(num_centers)
  print(last_plot())
}))
# "Mammoth" point cloud from PAIR's understanding-umap repository (the file
# name suggests 3-D coordinates); only the V1/V2 columns are plotted here.
PAIR_code_umap_mammoth_url <- "https://raw.githubusercontent.com/PAIR-code/understanding-umap/refs/heads/master/raw_data/mammoth_3d.json"
library(ggdims)
mammoth_df <- as.data.frame(jsonlite::fromJSON(PAIR_code_umap_mammoth_url))

set_num_centers(4)
ggplot(mammoth_df, aes(V1, V2)) +
  geom_kmeans() +
  geom_kmeans_center()

# Layers: 1 = clustered points, 2 = point-to-centre segments, 3 = centres.
ggplot(mammoth_df, aes(V1, V2)) +
  geom_kmeans() +
  geom_kmeans_lengths(linewidth = .2, alpha = .2) +
  geom_kmeans_center()
## Joining with `by = join_by(.cluster)`
# Computed data of layer 3 of the last plot (the geom_kmeans_center() layer):
# one row per cluster centre, with the centre in x/y and its `cluster` id.
layer_three <- layer_data(i = 3)
## Joining with `by = join_by(.cluster)`

# Partition the plane into the Voronoi cells of the cluster centres
# (candidate for a future geom_kmeans_partition()).
# group = -1L makes ggforce tessellate all centres together. Previously the
# discrete `fill` split the centres into one group each, so every group held a
# single point and deldir::deldir() failed with "The x-range of the points is
# zero, whence a rectangular window cannot be inferred from the data".
# The tiles are semi-transparent so the layers underneath stay visible.
last_plot() +
  ggforce::geom_voronoi_segment(
    data = layer_three,
    aes(x = x, y = y)
  ) +
  ggforce::geom_voronoi_tile(
    data = layer_three |> rename(.cluster = cluster),
    aes(x = x, y = y, fill = .cluster, group = -1L),
    alpha = .2
  )
## Joining with `by = join_by(.cluster)`
library(tidyverse)

# This snippet uses the column names of R >= 4.5's datasets::penguins
# (bill_len, bill_dep). Name it explicitly: the bare `penguins` used earlier in
# this file has bill_length_mm / bill_depth_mm columns (presumably modeldata's,
# attached by tidymodels) and would be found first.
penguins_base <- datasets::penguins

# Complete cases over *all* columns, so the clustered rows line up one-to-one
# with the full rows handed to augment() below.
ind_complete <- complete.cases(penguins_base)
penguins_complete <- penguins_base |>
  filter(ind_complete)

# Fit once and reuse. kmeans() uses random starts, so two separate calls can
# give different clusters (or permuted labels), and tidy() and augment() would
# then not describe the same fit.
penguins_kmeans <- penguins_complete |>
  select(bill_len, bill_dep) |>
  kmeans(centers = 4)

broom::tidy(penguins_kmeans)

broom::augment(penguins_kmeans, penguins_complete)