Intro Thoughts

Campitelli: stat_rasa (now stat_group() and extending to stat_panel())

https://eliocamp.github.io/codigo-r/en/2018/05/how-to-make-a-generic-stat-in-ggplot2/

Backend StatRasa, stat_group()

Define group-wise computation.

library(tidyverse)

# Group-wise compute method used by StatRasagroup.
#
# Overrides the defaults of `compute_group_fun`'s own arguments with the
# matching entries of `fun.args`, then applies the function to `data`.
#
# Arguments:
#   data              the data handed to the stat's compute_group method.
#   scales            not used here; kept because it is part of the
#                     compute_group signature.
#   compute_group_fun function called as compute_group_fun(data).
#   fun.args          list of candidate argument values (may be NULL/empty).
#                     Only named entries whose name is an argument of
#                     compute_group_fun are applied; everything else (e.g.
#                     geom parameters such as `size`) is ignored.
# Returns: the value of compute_group_fun(data).
compute_group_rasa <- function(data, scales, compute_group_fun, fun.args) {
     args <- formals(compute_group_fun)
     arg_names <- names(fun.args)
     for (i in seq_along(fun.args)) {
        arg_name <- arg_names[i]
        # Unnamed entries cannot be matched to an argument; skip them.
        if (is.null(arg_name) || is.na(arg_name) || !nzchar(arg_name)) next
        # Only touch arguments the function really has. Injecting unrelated
        # names as new formals could shadow variables the function body
        # looks up in its enclosing environment.
        if (arg_name %in% names(args)) {
           # `[<-` with list() keeps the argument even when the value is NULL
           # (`[[<-` with NULL would delete the formal instead).
           args[arg_name] <- list(fun.args[[i]])
        }
     }
     formals(compute_group_fun) <- args

     # Apply function to data
     compute_group_fun(data)
}


# ggproto Stat whose group-wise computation is delegated to
# compute_group_rasa().
StatRasagroup <- ggplot2::ggproto(
  "StatRasagroup",
  ggplot2::Stat,
  compute_group = compute_group_rasa
)

# Generic layer that applies a user-supplied function group-wise.
# Note the argument order differs from conventional stat_*() functions: the
# compute function comes first and the geom second, so both can be given
# positionally.
#
# Arguments:
#   compute_group_fun function applied to each group's data (required).
#   geom, mapping, data, position, show.legend, inherit.aes
#                     passed to ggplot2::layer() unchanged.
#   required_aes, default_aes, dropped_aes
#                     optional overrides for the stat's fields of the same
#                     name; NULL leaves the StatRasagroup value in place.
#   ...               forwarded to the layer params and also offered to
#                     compute_group_fun as argument overrides.
# Returns: the object created by ggplot2::layer().
# Errors: stops if compute_group_fun is not a function.
stat_group <- function(
                       compute_group_fun = NULL,
                       geom = "point",
                       mapping = NULL,
                       data = NULL,
                       position = "identity",
                       required_aes = NULL,
                       default_aes = NULL,
                       dropped_aes = NULL,
                       ...,
                       show.legend = NA,
                       inherit.aes = TRUE) {
   # Check arguments
   if (!is.function(compute_group_fun)) stop("compute_group_fun must be a function")

   # Pass dotted arguments to a list (captured unevaluated by match.call())
   fun.args <- match.call(expand.dots = FALSE)$`...`

   # Customise a per-layer child of StatRasagroup. Assigning the fields on
   # StatRasagroup itself would modify the shared ggproto object in place and
   # leak these settings into every other stat_group() layer.
   layer_stat <- ggplot2::ggproto(NULL, StatRasagroup)
   if (!is.null(required_aes)) layer_stat$required_aes <- required_aes
   if (!is.null(default_aes)) layer_stat$default_aes <- default_aes
   if (!is.null(dropped_aes)) layer_stat$dropped_aes <- dropped_aes

   ggplot2::layer(
      data = data,
      mapping = mapping,
      stat = layer_stat,
      geom = geom,
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      check.aes = FALSE,
      check.param = FALSE,
      params = list(
         compute_group_fun = compute_group_fun,
         fun.args = fun.args,
         na.rm = FALSE,
         ...
      )
   )
}

extending: Define panel-wise computation

library(tidyverse)

# ggproto Stat whose panel-wise computation is delegated to a user-supplied
# function.
#
# compute_panel arguments:
#   data              the data handed to the stat's compute_panel method.
#   scales            not used here; kept as part of the method signature.
#   compute_panel_fun function called as compute_panel_fun(data).
#   fun.args          list of candidate argument values; only named entries
#                     matching an argument of compute_panel_fun are applied.
# Returns: the value of compute_panel_fun(data).
StatRasapanel <- ggplot2::ggproto("StatRasapanel", ggplot2::Stat,
  compute_panel = function(data, scales, compute_panel_fun, fun.args) {
     # Change default arguments of the function to the
     # values in fun.args
     args <- formals(compute_panel_fun)
     arg_names <- names(fun.args)
     for (i in seq_along(fun.args)) {
        arg_name <- arg_names[i]
        # Unnamed entries cannot be matched to an argument; skip them.
        if (is.null(arg_name) || is.na(arg_name) || !nzchar(arg_name)) next
        # Compare against the function's own arguments so that unrelated
        # layer parameters are not injected as new formals.
        if (arg_name %in% names(args)) {
           # list() keeps the argument even when the override is NULL.
           args[arg_name] <- list(fun.args[[i]])
        }
     }
     formals(compute_panel_fun) <- args

     # Apply function to data
     compute_panel_fun(data)
})

# Generic layer that applies a user-supplied function panel-wise.
# As with stat_group(), the compute function comes first and the geom second
# (reordered from the conventional stat_*() signature).
#
# Arguments:
#   compute_panel_fun function applied to each panel's data (required).
#   geom, mapping, data, position, show.legend, inherit.aes
#                     passed to ggplot2::layer() unchanged.
#   required_aes, default_aes, dropped_aes
#                     optional overrides for the stat's fields of the same
#                     name; NULL leaves the StatRasapanel value in place.
#   ...               forwarded to the layer params and also offered to
#                     compute_panel_fun as argument overrides.
# Returns: the object created by ggplot2::layer().
# Errors: stops if compute_panel_fun is not a function.
stat_panel <- function(compute_panel_fun = NULL,
                       geom = "point",
                       mapping = NULL, data = NULL,
                       position = "identity",
                       required_aes = NULL,
                       default_aes = NULL,
                       dropped_aes = NULL,
                       ...,
                       show.legend = NA,
                       inherit.aes = TRUE) {
   # Check arguments
   if (!is.function(compute_panel_fun)) stop("compute_panel_fun must be a function")

   # Pass dotted arguments to a list (captured unevaluated by match.call())
   fun.args <- match.call(expand.dots = FALSE)$`...`

   # Customise a per-layer child of StatRasapanel. Assigning the fields on
   # StatRasapanel itself would modify the shared ggproto object in place and
   # leak these settings into every other stat_panel() layer.
   layer_stat <- ggplot2::ggproto(NULL, StatRasapanel)
   if (!is.null(required_aes)) layer_stat$required_aes <- required_aes
   if (!is.null(default_aes)) layer_stat$default_aes <- default_aes
   if (!is.null(dropped_aes)) layer_stat$dropped_aes <- dropped_aes

   ggplot2::layer(
      data = data,
      mapping = mapping,
      stat = layer_stat,
      geom = geom,
      position = position,
      show.legend = show.legend,
      inherit.aes = inherit.aes,
      check.aes = FALSE,
      check.param = FALSE,
      params = list(
         compute_panel_fun = compute_panel_fun,
         fun.args = fun.args,
         na.rm = FALSE,
         ...
      )
   )
}

# NOTE(review): this looks like an illustrative sketch rather than runnable
# code - stat_group() has no `fun` argument (its first argument is
# `compute_group_fun`) and `summarize_xy` is not defined in the visible file.
stat_group(fun = summarize_xy)

# Collapse the data to a single row holding the mean of x and the mean of y.
group_means <- function(data) {
  summarise(data, x = mean(x), y = mean(y))
}

# Layer built on stat_group() with group_means() as the compute function;
# `...` is forwarded to stat_group().
geom_means <- function(...) {
  stat_group(
    compute_group_fun = group_means,
    ...
  )
}

# Example: scatter plot of mtcars with a geom_means() layer on top;
# size = 6 is forwarded through `...` to stat_group().
mtcars |>
  ggplot() + 
  aes(x = wt,
      y = mpg) + 
  geom_point() + 
  geom_means(size = 6)

# Same plot with a colour mapping added on factor(cyl).
last_plot() + 
  aes(color = factor(cyl))

# For each value of `label` in the data, compute the mean of x and the mean
# of y (missing values removed). Returns one row per label.
group_label_at_center <- function(data) {
  data %>%
    group_by(label) %>%
    # TRUE rather than T: T is an ordinary variable and can be reassigned.
    summarise(x = mean(x, na.rm = TRUE),
              y = mean(y, na.rm = TRUE))
}

# Layer built on stat_group(): group_label_at_center() computes the positions
# and GeomLabel draws them; `...` is forwarded to stat_group().
geom_center_label <- function(...) {
  stat_group(
    compute_group_fun = group_label_at_center,
    geom = GeomLabel,
    ...
  )
}


# Example: penguins scatter plot with a constant label "All" drawn by
# geom_center_label().
palmerpenguins::penguins |>
  ggplot() +
  aes(x = bill_length_mm,
      y = bill_depth_mm) +
  geom_point() +
  aes(label = "All") +
  geom_center_label()
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_point()`).

# Same plot with colour and label both mapped to species.
last_plot() +
  aes(color = species, label = species)
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_point()`).

# Layer built on stat_group(): group_label_at_center() computes the positions
# and GeomText draws them; `...` is forwarded to stat_group().
geom_center_text <- function(...) {
  stat_group(
    compute_group_fun = group_label_at_center,
    geom = GeomText,
    ...
  )
}


# Example: points coloured by species, with geom_center_text() given its own
# label mapping and fixed color/alpha/size/fontface values.
palmerpenguins::penguins |>
  ggplot() +
  aes(x = bill_length_mm,
      y = bill_depth_mm) +
  geom_point() +
  aes(color = species) +
  geom_center_text(aes(label = species), 
                    color = "Black", 
                    alpha = .8,
                   size = 5, 
                   fontface = "bold")
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_point()`).

# Inspect the data computed for the second layer (the text layer).
layer_data(i = 2)
##       label        x        y PANEL group colour size angle hjust vjust alpha
## 1    Adelie 38.79139 18.34636     1    -1  Black    5     0   0.5   0.5   0.8
## 2 Chinstrap 48.83382 18.42059     1    -1  Black    5     0   0.5   0.5   0.8
## 3    Gentoo 47.50488 14.98211     1    -1  Black    5     0   0.5   0.5   0.8
##   family fontface lineheight
## 1            bold        1.2
## 2            bold        1.2
## 3            bold        1.2
# Add segment end points to every row: xend copies x and yend is fixed at 0.
compute_post <- function(data) {
  mutate(data, xend = x, yend = 0)
}

# Layer built on stat_group(): compute_post() supplies xend/yend and the
# "segment" geom draws the result; `...` is forwarded to stat_group().
geom_post <- function(...) {
  stat_group(
    compute_group_fun = compute_post,
    geom = "segment",
    ...
  )
}


# Example: two outcomes with probabilities .4 and .6, drawn with geom_post()
# segments plus points.
data.frame(outcome = 0:1, prob = c(.4, .6)) |>
  ggplot() + 
  aes(x = outcome,
      y = prob) + 
  geom_post() + 
  geom_point() + 
  labs(title = "probability by outcome")

# Reduce the data to one row whose `xintercept` column is the mean of x.
compute_xmean <- function(data) {
  summarize(data, xintercept = mean(x))
}

# Layer built on stat_group(): compute_xmean() supplies xintercept for the
# "vline" geom; x and y are declared as dropped aesthetics. `...` is
# forwarded to stat_group().
geom_xmean <- function(...) {
  stat_group(
    compute_group_fun = compute_xmean,
    geom = "vline",
    dropped_aes = c("x", "y"),
    ...
  )
}

# Example: mtcars scatter plot with a dashed geom_xmean() line.
mtcars |>
  ggplot() + 
  aes(x = wt,
      y = mpg) + 
  geom_point() + 
  geom_xmean(linetype = "dashed")

# Same plot with a colour mapping added on factor(cyl).
last_plot() + 
  aes(color = factor(cyl))

# Reduce the data to one row holding the q-th quantile of x and of y
# (q defaults to .25).
compute_xy_quantile <- function(data, q = .25) {
  summarise(
    data,
    x = quantile(x, q),
    y = quantile(y, q)
  )
}

# Layer built on stat_group(): compute_xy_quantile() computes the position
# and the "point" geom draws it; `...` is forwarded to stat_group().
# NOTE(review): ggplot2 also exports a geom_quantile(); this definition masks
# it wherever this one is in scope.
geom_quantile <- function(...) {
  stat_group(
    compute_group_fun = compute_xy_quantile,
    geom = "point",
    ...
  )
}

# Example: two geom_quantile() layers, the first with the default q and the
# second with q = .75 supplied through `...`.
mtcars |>
  ggplot() +
  aes(x = wt,
      y = mpg) +
  geom_point() +
  geom_quantile(size = 8) +
  geom_quantile(size = 8, q = .75)

Stat Rasa one-liners

Since we have organized the arguments with the compute function in first position and the geom in second position, we can write one-liners (or two-liners) that rely on positional arguments.

geom_xmean_line() in 137 characters

library(tidyverse)
# One-liner: stat_group() with an inline compute function (first positional
# argument) returning xintercept = mean(x), drawn with the "vline" geom
# (second positional argument); x and y are declared as dropped aesthetics.
geom_xmean_line <- function(...){stat_group(function(df) df |> summarize(xintercept = mean(x)), "vline", dropped_aes = c("x", "y"), ...)}


# Example: cars scatter plot with a dashed geom_xmean_line().
ggplot(cars) +
  aes(speed, dist) + 
  geom_point() + 
  geom_xmean_line(linetype = 'dashed')

# Same plot with colour mapped to the logical dist > 50.
last_plot() + 
  aes(color = dist > 50)

# One-liner that redefines geom_xmean(): a "point" layer at x = mean(x) with
# y set to I(.05).
# NOTE(review): presumably I() is used so the y value is taken as-is rather
# than as a data value - confirm against the ggplot2 version in use.
geom_xmean <- function(...){stat_group(function(df) df |> summarize(x = mean(x), y = I(.05)), "point", ...)}


# Example: cars scatter plot with the point version of geom_xmean().
ggplot(cars) +
  aes(speed, dist) + 
  geom_point() + 
  geom_xmean(size = 8, shape = "diamond") 

geom_post() in 101 characters

# One-liner version of geom_post(): the inline compute function adds
# xend = x and yend = 0, and the "segment" geom draws the result.
geom_post <- function(...){stat_group(function(df) df |> mutate(xend = x, yend = 0), "segment", ...)}

# Example: the two-outcome probability data drawn with geom_post() and points.
data.frame(prob = c(.4,.6), outcome = c(0, 1)) %>% 
ggplot(data = .) +
  aes(outcome, prob) + 
  geom_post() +
  geom_point() 

# One-liner: a "point" layer at x = sum(x * y), y = 0. With x mapped to the
# outcome and y to its probability (as in the plot above), that is the
# expected value.
geom_expectedvalue <- function(...){stat_group(function(df) df |> summarise(x = sum(x*y), y = 0), "point", ...)} 

# Add the geom_expectedvalue() layer to the previous plot.
last_plot() + 
  geom_expectedvalue()

# One-liner: like geom_expectedvalue(), but adds a `label` column holding x
# rounded to 2 digits and draws it with the "text" geom; hjust = 1 and
# vjust = 0 are passed through stat_group()'s `...`.
geom_expectedvalue_label <- function(...){stat_group(function(df) df |> summarise(x = sum(x*y), y = 0) |> mutate(label = round(x, 2)), "text", hjust = 1, vjust = 0, ...)} 

# Add the geom_expectedvalue_label() layer to the previous plot.
last_plot() +
  geom_expectedvalue_label()

geom_proportion() and geom_proportion_label()

# Example data: fifteen 1s followed by a single 0, shown as a dot plot.
rep(1, 15) |> 
  c(0) %>% 
  data.frame(outcome = .) |>
  ggplot() + 
  aes(x = outcome) + 
  geom_dotplot()
## Bin width defaults to 1/30 of the range of the data. Pick better value with
## `binwidth`.

# One-liner using stat_panel(): a "point" layer at x = sum(x) / length(x)
# (the proportion of 1s when x is 0/1), y = 0, computed once per panel.
geom_proportion <- function(...){stat_panel(function(df) df |> summarise(x = sum(x)/length(x), y = 0), "point", ...)}   # this should work for T/F too when rasa_p is in play

# Add the geom_proportion() layer to the dot plot.
last_plot() + 
  geom_proportion()
## Bin width defaults to 1/30 of the range of the data. Pick better value with
## `binwidth`.

# One-liner using stat_panel(): like geom_proportion(), plus a `label` column
# holding x rounded to 2 digits. The unnamed "text" after the named vjust = 0
# is still matched positionally to stat_panel()'s second argument, geom.
geom_proportion_label <- function(...){stat_panel(function(df) df |> summarise(x = sum(x)/length(x), y = 0) |> mutate(label = round(x,2)), vjust = 0, "text", ...)}   # this should work for T/F too when rasa_p is in play

# Add the geom_proportion_label() layer to the previous plot.
last_plot() + 
  geom_proportion_label()
## Bin width defaults to 1/30 of the range of the data. Pick better value with
## `binwidth`.

# last_plot() + 
#   geom_proportion_label() + 
#   ggsample::facet_bootstrap()
# 
# layer_data(i = 2)
# 
# 
# 
# rep(0:1, 10000) %>% # very large 50/50 sample
#   data.frame(outcome = .) |>
#   ggplot() + 
#   aes(x = outcome) + 
#   geom_dotplot() +
#   geom_proportion() + 
#   geom_proportion_label() + 
#   ggsample::facet_sample(n_facets = 25,
#     n_sampled = 16) ->
# p; p 
#   
# 
# layer_data(p, i = 2) |>
#   ggplot() + 
#   aes(x = x) + 
#   geom_rug() + 
#   geom_histogram()

geom_means in 131 characters

# One-liner version of geom_means(): a "point" layer at the group means of x
# and y, with missing values removed.
# NOTE(review): `T` is shorthand for `TRUE` and can be reassigned; it is left
# as-is here because the heading above counts this line's characters.
geom_means <- function(...){stat_group(function(df) df |> summarize(x = mean(x, na.rm = T), y = mean(y, na.rm = T)), "point", ...)}

# Example: penguins scatter plot with a geom_means() point of size 7.
palmerpenguins::penguins %>% 
  ggplot() + 
  aes(bill_length_mm, bill_depth_mm) + 
  geom_point() + 
  geom_means(size = 7)
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_point()`).

## geom_grouplabel_at_means()

# One-liner: a "label" layer placed, for each value of `label`, at the mean
# of x and y (missing values removed). Uses TRUE rather than the reassignable
# shorthand T.
geom_grouplabel_at_means <-  function(...){stat_group(function(df) df |> group_by(label) |> summarize(x = mean(x, na.rm = TRUE), y = mean(y, na.rm = TRUE)), "label", ...)}

# Example: penguins scatter plot with species labels at the group means;
# label and group are both mapped to species.
palmerpenguins::penguins %>% 
  ggplot() + 
  aes(bill_length_mm, bill_depth_mm, label = species, group = species) + 
  geom_point() + 
  geom_grouplabel_at_means(size = 7)
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_point()`).

compute_panel_highlight_lines()

# Flag the rows whose `id` is listed in `which_id`, sort so the flagged rows
# come last, and set `group` from `id` with levels in order of appearance.
compute_panel_highlight_lines <- function(data, which_id = NULL) {
  flagged <- mutate(data, ind_id = id %in% which_id)
  sorted <- arrange(flagged, ind_id)
  mutate(sorted, group = fct_inorder(id))
}

# Layer built on stat_panel(): compute_panel_highlight_lines() prepares the
# data for the "line" geom, and colour defaults to the computed `ind_id`
# flag. `...` is forwarded to stat_panel().
geom_line_highlight <- function(...) {
  stat_panel(
    compute_panel_fun = compute_panel_highlight_lines,
    geom = "line",
    default_aes = aes(color = after_stat(ind_id)),
    ...
  )
}

# Example: life expectancy over time for countries in the Americas, with
# country supplied as the `id` aesthetic and which_id = "Bolivia" forwarded
# to compute_panel_highlight_lines().
gapminder::gapminder %>% 
  filter(continent == "Americas") %>% 
  ggplot() + 
  aes(x = year, y = lifeExp, id = country) + 
  geom_point() + 
  geom_line_highlight(which_id = "Bolivia",
                      linewidth = 3)

Closing remarks, Other Relevant Work, Caveats

Original example and variant

# Replace y with the residuals of a fit of y on x and return the data as a
# data.frame. method = "lm" fits a linear model; any other value fits a
# loess curve with the given span (the .75 default matches ggplot2's
# geom_smooth).
Detrend <- function(data, method = "lm", span = 0.75) {
  fit <- if (method == "lm") {
    lm(y ~ x, data = data)
  } else {
    loess(y ~ x, span = span, data = data)
  }
  data$y <- resid(fit)
  as.data.frame(data)
}
   
   
# Example data: 30 evenly spaced x values on [-1, 3] with y = x^2 plus
# normal noise scaled by 0.5; the seed is fixed for reproducibility.
library(ggplot2)
set.seed(42)
x <- seq(-1, 3, length.out = 30)
y <- x^2 + rnorm(30)*0.5
df <- data.frame(x = x, y = y)

# Example: raw data, a smoother, and the detrended series from stat_group().
# method = "smooth" is forwarded to Detrend(); because it is not "lm",
# Detrend() takes its loess branch.
ggplot(df, aes(x, y)) +
  geom_line(aes(color = "raw data")) +
  geom_smooth(aes(color = "loess smoothing"),
              alpha = .3) + 
  stat_smooth(geom = "point", 
              aes(color = "loess smoothing"),
              xseq = df$x) +
  stat_group(geom = "line", 
             compute_group_fun = Detrend, 
             method = "smooth",
             aes(color = "detrended")) +
  geom_hline(yintercept = 0, 
             linetype = "dashed") + 
  scale_color_discrete(breaks = 
                         fct_inorder(c("raw data",
                                       "loess smoothing",
                                       "detrended"))) +
  labs(title = "detrending with loess smoothing")
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'

UI, stat_group -> stat_detrend

# Layer built on stat_group(): Detrend() is the compute function and the
# "line" geom draws the result; `...` is forwarded to stat_group().
stat_detrend <- function(...) {
   stat_group(
      compute_group_fun = Detrend,
      geom = "line",
      ...
   )
}

# Example: linear detrending with stat_detrend(); method = "lm" is forwarded
# to Detrend(), which then uses its linear-model branch.
ggplot(df, aes(x, y)) + 
  geom_line(aes(color = "raw data")) +
  geom_smooth(method = "lm", 
              aes(color = "linear model")) +
  stat_smooth(method = "lm", geom = "segment",
              xend = df$x, yend = df$y,
              xseq = df$x, alpha = .2
              ) +
  stat_detrend(method = "lm", 
                aes(color = "detrended")) +
  geom_hline(yintercept = 0, 
             linetype = "dashed") + 
  scale_color_discrete(breaks = 
                         fct_inorder(c("raw data",
                                       "linear model",
                                       "detrended"))) + 
  labs(title = "Linear detrending",
       color = NULL)
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'