Intro Thoughts

New Stat functionality can help ggplot2 become more expressive and more of a joy to use. However, guidance for creating new layers has typically involved a fair amount of ‘scaffolding’ code, which can be hard to manage and can make scripts less readable.

This exercise explores some new strategies for using new Stat definitions.

We show two ‘express’ extension strategies: one that requires no dependencies, and another that uses the statexpress package to define user-facing functions. Both aim to let authors skip the ‘scaffolding’ (boilerplate) code and focus their attention on the computation that gets them from input data to render-ready data.

We start with a ‘base’ ggplot2 build, then turn to on-the-fly extension.

Step 0. Status Quo, no use of extension mechanisms (ggproto)

Inspiration, and data, for this exercise comes from… https://rfortherestofus.com/2024/07/population-pyramid-part-1

We build the plot in a single panel (no patchwork) as used in the blog post. We’ll reuse the computation worked out here in the extensions in the next section.

library(tidyverse)


# Read the Oregon population-pyramid data straight from the blog's repository.
oregon_population_pyramid_data <- read_csv(
  file = "https://raw.githubusercontent.com/rfortherestofus/blog/main/population-pyramid-part-1/oregon_population_pyramid_data.csv"
)
## Rows: 1296 Columns: 4
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (3): county, age, gender
## dbl (1): percent
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Half-width of the gap opened around x = 0, leaving room for the age labels.
sep <- .01

# We use geom_rect() for its flexibility instead of geom_col()/geom_bar().
oregon_population_pyramid_data %>%
  filter(county == "Baker") %>%
  # Move 'zero' out from the center: "Men" shift right, every other value left.
  mutate(zero_shift = ifelse(gender == "Men", sep, -sep)) %>%
  mutate(xmin = 0 + zero_shift) %>%
  # xmax is mirrored for the non-"Men" side and then shifted.
  mutate(xmax = percent * ifelse(gender == "Men", 1, -1) + zero_shift) %>%
  mutate(y = as.numeric(fct_inorder(age))) %>%
  # geom_rect() needs explicit ymin/ymax; each bar spans y -/+ 0.45.
  mutate(ymax = y + .45) %>%
  mutate(ymin = y - .45) %>%
  ggplot() +
  aes(y = fct_inorder(age), x = percent) +
  geom_point() +
  geom_rect(aes(xmin = xmin, xmax = xmax,
                ymin = ymin, ymax = ymax)) +
  aes(fill = gender) +
  aes(label = age) +
  # One label per age group, placed in the gap at x = 0.
  geom_text(aes(x = 0, fill = NULL), data = . %>% distinct(age))

Option 1. dependency-free, light-weight layer definition

Use ggproto to define a new Stat, then pass it directly to an existing user-facing geom_*() function through its stat argument.

Create new Stat (upper-case)

# Compute rectangle coordinates for a split pyramid.
#
# Rows whose `pyramid_cat` equals `neg_cat` extend to the left of zero; all
# other rows extend to the right.
#
# data:    data frame with columns `x`, `y` and `pyramid_cat`.
# scales:  accepted for the compute-function signature; not used here.
# sep:     distance each side is pushed away from zero.
# neg_cat: category drawn on the negative side; when NULL, the first element
#          of `sort(data$pyramid_cat)` is used.
#
# Returns `data` with `split`, `xmin`, `xmax`, `ymax` and `ymin` added and `y`
# replaced by its numeric factor code.
compute_panel_pyramid <- function(data, scales, sep = 0, neg_cat = NULL) {
  if (is.null(neg_cat)) {
    neg_cat <- sort(data$pyramid_cat)[1]
  }

  # Arguments of a single mutate() are evaluated in order, so later columns
  # can refer to the ones created just above them.
  mutate(
    data,
    split = ifelse(pyramid_cat == neg_cat, -sep, sep),
    xmin = 0 + split,
    xmax = x * ifelse(pyramid_cat == neg_cat, -1, 1) + split,
    y = as.numeric(as.factor(y)),
    ymax = y + .45,
    ymin = y - .45
  )
}

# Stat that runs compute_panel_pyramid() once per panel, producing the
# xmin/xmax/ymin/ymax columns a rect geom draws.
StatPyramid <- ggproto(
  "StatPyramid",
  Stat,
  compute_panel = compute_panel_pyramid,
  # Declare the inputs so an unmapped aesthetic is reported up front rather
  # than failing inside the compute function.
  required_aes = c("x", "y", "pyramid_cat"),
  # Fill by category unless the user maps or sets fill themselves.
  default_aes = aes(fill = after_stat(pyramid_cat))
)


# Build the label data: one row per distinct (y, label) pair, placed at x = 0.
# `scales` is accepted for the compute-function signature but not used.
compute_pyramid_label <- function(data, scales) {
  label_rows <- distinct(data, y, label)
  mutate(label_rows, x = 0)
}

# Stat that reduces the layer data to one centred label per y position via
# compute_pyramid_label().
StatPyramidlabel <- ggproto(
  "StatPyramidlabel",
  Stat,
  compute_panel = compute_pyramid_label,
  # compute_pyramid_label() reads exactly these two columns; declaring them
  # gives a clear error when one is not mapped.
  required_aes = c("y", "label")
)

use new Stats with existing geom_*() functions

# Pyramid for Baker county drawn with the new Stats and the existing
# geom_rect() / geom_text() functions.
oregon_population_pyramid_data %>%
  filter(county == "Baker") %>%
  ggplot(
    mapping = aes(
      x = percent * 100,
      y = fct_inorder(age),
      pyramid_cat = gender,
      label = age
    )
  ) +
  geom_rect(stat = StatPyramid, sep = 1) +
  geom_text(stat = StatPyramidlabel) +
  # Breaks are offset by 1 on each side to match sep = 1 above.
  scale_x_continuous(
    breaks = c(-10:0 - 1, 0:10 + 1),
    labels = paste0(c(10:0, 0:10), "%"),
    limits = c(-10, 10)
  ) +
  labs(y = NULL, x = NULL, fill = NULL) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "top",
    legend.justification = "left",
    panel.grid = element_blank(),
    panel.background = element_rect(fill = "whitesmoke")
  )

Option 2. Use statexpress::stat_panel() to define user-facing functions

The statexpress package creates StatTemp in the background and wraps up the ‘scaffolding’ that is more typically used in defining user-facing functions. You won’t be able to use StatTemp itself.

use statexpress to build functions

# Compute rectangle coordinates for a split pyramid.
#
# Rows whose `pyramid_cat` equals `neg_cat` extend to the left of zero; all
# other rows extend to the right.
#
# data:    data frame with columns `x`, `y` and `pyramid_cat`.
# scales:  accepted for the compute-function signature; not used here.
# sep:     distance each side is pushed away from zero.
# neg_cat: category drawn on the negative side; when NULL, the first element
#          of `sort(data$pyramid_cat)` is used.
#
# Returns `data` with `split`, `xmin`, `xmax`, `ymax` and `ymin` added and `y`
# replaced by its numeric factor code.
compute_panel_pyramid <- function(data, scales, sep = 0, neg_cat = NULL) {
  if (is.null(neg_cat)) {
    neg_cat <- sort(data$pyramid_cat)[1]
  }

  # Arguments of a single mutate() are evaluated in order, so later columns
  # can refer to the ones created just above them.
  mutate(
    data,
    split = ifelse(pyramid_cat == neg_cat, -sep, sep),
    xmin = 0 + split,
    xmax = x * ifelse(pyramid_cat == neg_cat, -1, 1) + split,
    y = as.numeric(as.factor(y)),
    ymax = y + .45,
    ymin = y - .45
  )
}

# User-facing layer for the pyramid bars: hands compute_panel_pyramid() to
# statexpress::stat_panel() with a "rect" geom and a category-based default
# fill. Everything in `...` is forwarded to stat_panel().
geom_split_pyramid <- function(...) {
  statexpress::stat_panel(
    fun = compute_panel_pyramid,
    geom = "rect",
    default_aes = aes(fill = after_stat(pyramid_cat)),
    ...
  )
}


# Build the label data: one row per distinct (y, label) pair, placed at x = 0.
# `scales` is accepted for the compute-function signature but not used.
compute_pyramid_label <- function(data, scales) {
  label_rows <- distinct(data, y, label)
  mutate(label_rows, x = 0)
}

# User-facing layer for the centre labels: hands compute_pyramid_label() to
# statexpress::stat_panel() with a "text" geom. Everything in `...` is
# forwarded to stat_panel().
geom_split_pyramid_label <- function(...) {
  statexpress::stat_panel(
    fun = compute_pyramid_label,
    geom = "text",
    ...
  )
}

test them in plot

# Build the pyramid with the statexpress-based layers and keep it as `base`.
base <- oregon_population_pyramid_data %>%
  filter(county == "Baker") %>%
  ggplot() +
  aes(x = percent * 100,
      y = fct_inorder(age),
      pyramid_cat = gender) +
  geom_split_pyramid(sep = 1) +
  aes(label = age) +
  geom_split_pyramid_label()

# Print the unstyled plot.
base

# styling
base +
  # Breaks are offset by 1 on each side to match sep = 1 above.
  scale_x_continuous(breaks = c(-10:0 - 1, 0:10 + 1),
                     labels = paste0(c(10:0, 0:10), "%"),
                     limits = c(-10, 10)) +
  labs(y = NULL, x = NULL, fill = NULL) +
  theme(axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        legend.position = "top",
        legend.justification = "left",
        panel.grid = element_blank(),
        panel.background = element_rect(fill = "whitesmoke"))

# Inspect the computed data of layer 2 (the label layer); no plot is passed,
# so layer_data() falls back to its default plot argument.
layer_data(i = 2)

To do, notes


Much simpler examples of ‘express stats’

Option 1. No dependency

library(ggplot2)
# Stat that collapses each group to a single point at (mean(x), mean(y)).
StatMeans <- ggproto(
  "StatMeans",
  Stat,
  # Both coordinates are read by compute_group(); declaring them gives a clear
  # error when one is not mapped.
  required_aes = c("x", "y"),
  compute_group = function(data, scales) {
    data.frame(x = mean(data$x), y = mean(data$y))
  }
)

# Raw points plus a second point layer that is computed with StatMeans.
ggplot(data = cars, mapping = aes(x = dist, y = speed)) +
  geom_point() +
  geom_point(stat = StatMeans, size = 8)

Option 2. {statexpress}

library(tidyverse)

library(statexpress)
# User-facing layer that draws a "vline" at the mean of x. The summary function
# is handed to stat_group() together with the geom name; `x` and `y` are listed
# as dropped aesthetics, and `...` is forwarded to stat_group().
geom_xmean_line <- function(...) {
  compute_xmean <- function(df) {
    summarize(df, xintercept = mean(x))
  }

  stat_group(compute_xmean, "vline", dropped_aes = c("x", "y"), ...)
}


# Scatterplot of cars with a dashed line from geom_xmean_line().
ggplot(data = cars, mapping = aes(x = speed, y = dist)) +
  geom_point() +
  geom_xmean_line(linetype = "dashed")

# Same plot again with a colour mapping on whether dist exceeds 50.
last_plot() +
  aes(color = dist > 50)

For reprex

library(tidyverse)

# Read the Oregon population-pyramid data straight from the blog's repository.
oregon_population_pyramid_data <- read_csv(
  file = "https://raw.githubusercontent.com/rfortherestofus/blog/main/population-pyramid-part-1/oregon_population_pyramid_data.csv"
)
## Rows: 1296 Columns: 4
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (3): county, age, gender
## dbl (1): percent
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Compute rectangle coordinates for a split pyramid.
#
# Rows whose `pyramid_cat` equals `neg_cat` extend to the left of zero; all
# other rows extend to the right.
#
# data:    data frame with columns `x`, `y` and `pyramid_cat`.
# scales:  accepted for the compute-function signature; not used here.
# sep:     distance each side is pushed away from zero.
# neg_cat: category drawn on the negative side; when NULL, the first element
#          of `sort(data$pyramid_cat)` is used.
#
# Returns `data` with `split`, `xmin`, `xmax`, `ymax` and `ymin` added and `y`
# replaced by its numeric factor code.
compute_panel_pyramid <- function(data, scales, sep = 0, neg_cat = NULL) {
  if (is.null(neg_cat)) {
    neg_cat <- sort(data$pyramid_cat)[1]
  }

  # Arguments of a single mutate() are evaluated in order, so later columns
  # can refer to the ones created just above them.
  mutate(
    data,
    split = ifelse(pyramid_cat == neg_cat, -sep, sep),
    xmin = 0 + split,
    xmax = x * ifelse(pyramid_cat == neg_cat, -1, 1) + split,
    y = as.numeric(as.factor(y)),
    ymax = y + .45,
    ymin = y - .45
  )
}

# Stat that runs compute_panel_pyramid() once per panel, producing the
# xmin/xmax/ymin/ymax columns a rect geom draws.
StatPyramid <- ggproto(
  "StatPyramid",
  Stat,
  compute_panel = compute_panel_pyramid,
  # Declare the inputs so an unmapped aesthetic is reported up front rather
  # than failing inside the compute function.
  required_aes = c("x", "y", "pyramid_cat"),
  # Fill by category unless the user maps or sets fill themselves.
  default_aes = aes(fill = after_stat(pyramid_cat))
)


# Build the label data: one row per distinct (y, label) pair, placed at x = 0.
# `scales` is accepted for the compute-function signature but not used.
compute_pyramid_label <- function(data, scales) {
  label_rows <- distinct(data, y, label)
  mutate(label_rows, x = 0)
}


# Stat that reduces the layer data to one centred label per y position via
# compute_pyramid_label().
StatPyramidlabel <- ggproto(
  "StatPyramidlabel",
  Stat,
  compute_panel = compute_pyramid_label,
  # compute_pyramid_label() reads exactly these two columns; declaring them
  # gives a clear error when one is not mapped.
  required_aes = c("y", "label")
)


# Continuous x scale for a split pyramid: the same percentage labels are
# mirrored on both sides, with each side's breaks pushed `sep` away from zero.
#
# sep:    gap applied on each side of zero; should match the `sep` given to
#         the pyramid layer so ticks line up with the bars.
# breaks: percentage values to label on each side.
# ...:    further arguments passed on to scale_x_continuous().
scale_x_pyramid <- function(sep = 1, breaks = 0:7, ...) {
  scale_x_continuous(
    # Both sides shift by `sep`. The right side used a hard-coded 1, which
    # misplaced its ticks whenever sep != 1.
    breaks = c(-rev(breaks) - sep, breaks + sep),
    labels = paste0(c(rev(breaks), breaks), "%"),
    # NOTE(review): the limits ignore `sep`, so for sep > 0 the outermost
    # break (max(breaks) + sep) lies outside them — confirm this is intended.
    limits = c(-max(breaks), max(breaks)),
    ...
  )
}
  
  
# Pyramid coordinates with the category difference removed: every row's `x` is
# first replaced by the mean `x` of its `y` level, then the usual split-pyramid
# rectangle columns are computed.
#
# data:    data frame with columns `x`, `y` and `pyramid_cat`.
# scales:  accepted for the compute-function signature; not used here.
# sep:     distance each side is pushed away from zero.
# neg_cat: category drawn on the negative side; when NULL, the first element
#          of `sort(data$pyramid_cat)` is used.
compute_panel_pyramid_no_cat <- function(data, scales, sep = 0, neg_cat = NULL) {
  if (is.null(neg_cat)) {
    neg_cat <- sort(data$pyramid_cat)[1]
  }

  # Grouped step kept separate: `.by = y` applies to the whole mutate() call.
  averaged <- mutate(data, x = mean(x), .by = y)

  mutate(
    averaged,
    split = ifelse(pyramid_cat == neg_cat, -sep, sep),
    xmin = 0 + split,
    xmax = x * ifelse(pyramid_cat == neg_cat, -1, 1) + split,
    y = as.numeric(as.factor(y)),
    ymax = y + .45,
    ymin = y - .45
  )
}

# Stat for the reference pyramid: runs compute_panel_pyramid_no_cat() once per
# panel.
StatPyramidequivilant <- ggproto(
  "StatPyramidequivilant",
  Stat,
  compute_panel = compute_panel_pyramid_no_cat,
  # Declare the inputs so an unmapped aesthetic is reported up front rather
  # than failing inside the compute function.
  required_aes = c("x", "y", "pyramid_cat")
)


# Pyramid coordinates for the part shared by all categories: every row's `x` is
# first replaced by the minimum `x` of its `y` level, then the usual
# split-pyramid rectangle columns are computed.
#
# data:    data frame with columns `x`, `y` and `pyramid_cat`.
# scales:  accepted for the compute-function signature; not used here.
# sep:     distance each side is pushed away from zero.
# neg_cat: category drawn on the negative side; when NULL, the first element
#          of `sort(data$pyramid_cat)` is used.
compute_panel_pyramid_minimum <- function(data, scales, sep = 0, neg_cat = NULL) {
  if (is.null(neg_cat)) {
    neg_cat <- sort(data$pyramid_cat)[1]
  }

  # Grouped step kept separate: `.by = y` applies to the whole mutate() call.
  floored <- mutate(data, x = min(x), .by = y)

  mutate(
    floored,
    split = ifelse(pyramid_cat == neg_cat, -sep, sep),
    xmin = 0 + split,
    xmax = x * ifelse(pyramid_cat == neg_cat, -1, 1) + split,
    y = as.numeric(as.factor(y)),
    ymax = y + .45,
    ymin = y - .45
  )
}


# Stat for the shared-minimum overlay: runs compute_panel_pyramid_minimum()
# once per panel.
StatPyramidequivilantminimum <- ggproto(
  "StatPyramidequivilantminimum",
  Stat,
  compute_panel = compute_panel_pyramid_minimum,
  # Declare the inputs so an unmapped aesthetic is reported up front rather
  # than failing inside the compute function.
  required_aes = c("x", "y", "pyramid_cat")
)
# Reference pyramid for Baker county: light-grey bars computed with
# StatPyramidequivilant, plus the centre labels.
oregon_population_pyramid_data %>%
  filter(county == "Baker") %>%
  ggplot(
    mapping = aes(
      x = percent * 100,
      y = fct_inorder(age),
      pyramid_cat = gender,
      label = age
    )
  ) +
  geom_rect(
    mapping = aes(color = "expected, were there no Male-Female difference"),
    stat = StatPyramidequivilant,
    sep = 1,
    fill = "lightgrey",
    # linetype = "dotted",
    linewidth = .2
  ) +
  geom_text(stat = StatPyramidlabel) +
  scale_x_pyramid() +
  scale_color_manual(values = "darkgrey") +
  labs(y = NULL, x = NULL, fill = NULL, color = NULL) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "top",
    legend.justification = "left",
    panel.grid = element_blank(),
    panel.background = element_rect(fill = "whitesmoke")
  )

# Add the actual pyramid bars on top of the reference plot.
last_plot() +
  geom_rect(stat = StatPyramid, sep = 1, linewidth = .2)

# Actual pyramid for Baker county, drawn with StatPyramid plus centre labels.
oregon_population_pyramid_data %>%
  filter(county == "Baker") %>%
  ggplot(
    mapping = aes(
      x = percent * 100,
      y = fct_inorder(age),
      pyramid_cat = gender,
      label = age
    )
  ) +
  geom_rect(stat = StatPyramid, sep = 1, linewidth = .2) +
  geom_text(stat = StatPyramidlabel) +
  scale_x_pyramid() +
  scale_color_manual(values = "darkgrey") +
  labs(y = NULL, x = NULL, fill = NULL, color = NULL) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "top",
    legend.justification = "left",
    panel.grid = element_blank(),
    panel.background = element_rect(fill = "whitesmoke")
  )

# Overlay semi-transparent white bars computed with StatPyramidequivilantminimum.
last_plot() +
  geom_rect(
    stat = StatPyramidequivilantminimum,
    sep = 1,
    alpha = .45,
    fill = "white"
  )