In ‘Easy Geom Recipes’ the computation we did was within groups, specifying the compute_group parameter within ggproto.
Grouping variables in ggplot2 are character, factor or logical variables.
If we use the compute_group parameter in our ggproto function to define our stat, computation will happen on a group-wise basis whenever an aesthetic is mapped to such a variable.
In other words, our target layer geom_* will divide up our data set by any character, factor, or logical categories, and then do computation before returning a data frame for plotting.
This is not always the desired behavior. Sometimes computation has to happen across groups. The examples that follow are such cases.
First we’ll see a regression model that has an indicator variable as one of its inputs. We want the model to be computed holistically; geom_smooth(), which you may have used, computes groupwise; i.e. one model for each group. We’ll use panel-level computation to display across group model results.
The second example is circle packing. We can use the {packcircles} package to visualize quantities for entities. To wrap this work up into a geom function, we need the computation that positions the circles to happen at the panel level – it shouldn’t do the computation within groups.
Finally, we look at a join of sf geometries as our computation. Here, I think it may be possible to set either compute_group or compute_panel and get similar visual results. However, computing and plotting a panel (we’ll do this with an inner join of an input data set and an sf reference data frame) is much faster than computing group-wise. (These are just guesses; they still need to be tried with an sf data set that has fewer geometries than the FIPS counties.)
class: inverse, middle, center
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.0 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.4.1 ✔ tibble 3.2.0
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(palmerpenguins)
# Drop incomplete rows first; the fitted values are attached to the data
# column-wise below, so both must have the same number of rows.
penguins <- remove_missing(penguins)
## Warning: Removed 11 rows containing missing values.

# A single model for all penguins, with species entering as an indicator.
model <- lm(body_mass_g ~ flipper_length_mm + species, data = penguins)

penguins_w_fitted <- penguins %>%
  mutate(fitted = model$fitted.values)

# Raw observations as points, model fit as one line per species.
ggplot(penguins_w_fitted) +
  aes(x = flipper_length_mm, y = body_mass_g, color = species) +
  geom_point() +
  geom_line(aes(y = fitted, group = species))
# Panel-level computation: fit one linear model to all rows of `data` and
# return the fitted values alongside the observed ones.
#
# data:    data frame with columns `x`, `y` and `indicator`
# scales:  not used here; part of the compute_panel() signature
# formula: model formula, evaluated against the columns of `data`
#
# Returns a data frame with one row per input row:
#   x, indicator - copied from `data`
#   y            - fitted value (NA where the row could not enter the fit)
#   xend, yend   - observed x and y, for a residuals geom
compute_panel_ols_ind <- function(data, scales, formula = y ~ x + indicator) {
  # na.exclude + fitted() pads rows dropped for missingness with NA, so the
  # fitted vector always has nrow(data) elements. With the default
  # na.action, rows containing NA made data.frame() fail on unequal lengths.
  model <- stats::lm(
    formula = formula,
    data = data,
    na.action = stats::na.exclude
  )

  data.frame(
    x = data$x,
    y = stats::fitted(model),
    indicator = data$indicator,
    xend = data$x, # for residuals geom
    yend = data$y  # for residuals geom
  )
}
# Try the compute function on its own: name the columns after the
# aesthetics it expects, then look at the data frame it returns.
penguins %>%
  select(
    x = bill_length_mm,
    y = bill_depth_mm,
    indicator = species
  ) %>%
  remove_missing() %>%        # from ggplot2
  compute_panel_ols_ind() %>%
  tibble()                    # for nicer display
## # A tibble: 333 × 5
## x y indicator xend yend
## <dbl> <dbl> <fct> <dbl> <dbl>
## 1 39.1 18.4 Adelie 39.1 18.7
## 2 39.5 18.5 Adelie 39.5 17.4
## 3 40.3 18.6 Adelie 40.3 18
## 4 36.7 17.9 Adelie 36.7 19.3
## 5 39.3 18.4 Adelie 39.3 20.6
## 6 38.9 18.4 Adelie 38.9 17.8
## 7 39.2 18.4 Adelie 39.2 19.6
## 8 41.1 18.8 Adelie 41.1 17.6
## 9 38.6 18.3 Adelie 38.6 21.2
## 10 34.6 17.5 Adelie 34.6 21.1
## # … with 323 more rows
–
# Stat whose computation runs once per panel (compute_panel), not once per
# group. After the computation, rows are grouped by the `indicator` column.
StatLmindicator <- ggplot2::ggproto(
  `_class` = "StatLmindicator",
  `_inherit` = ggplot2::Stat,
  required_aes = c("x", "y", "indicator"),
  compute_panel = compute_panel_ols_ind,
  default_aes = ggplot2::aes(group = ggplot2::after_stat(indicator))
)
–
# Layer function pairing StatLmindicator with a line geom.
# Arguments passed through `...` (for example `formula`) are forwarded to
# the layer as parameters together with `na.rm`.
geom_lm_indicator <- function(mapping = NULL, data = NULL,
                              position = "identity", na.rm = FALSE,
                              show.legend = NA,
                              inherit.aes = TRUE, ...) {
  layer_params <- list(na.rm = na.rm, ...)

  ggplot2::layer(
    data = data,
    mapping = mapping,
    stat = StatLmindicator,      # proto object from Step 2
    geom = ggplot2::GeomLine,    # inherit other behavior
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = layer_params
  )
}
# Default formula of the layer.
ggplot(palmerpenguins::penguins) +
  aes(x = flipper_length_mm, y = body_mass_g,
      color = species, indicator = species) +
  geom_point() +
  geom_lm_indicator()
## Warning: Removed 2 rows containing non-finite values (`stat_lmindicator()`).
## Warning: Removed 2 rows containing missing values (`geom_point()`).

# Interaction between x and the indicator.
ggplot(palmerpenguins::penguins) +
  aes(x = flipper_length_mm, y = body_mass_g,
      color = species, indicator = species) +
  geom_point() +
  geom_lm_indicator(formula = y ~ x * indicator)
## Warning: Removed 2 rows containing non-finite values (`stat_lmindicator()`).
## Removed 2 rows containing missing values (`geom_point()`).

# Cubic polynomial in x plus the indicator.
ggplot(palmerpenguins::penguins) +
  aes(x = flipper_length_mm, y = body_mass_g,
      color = species, indicator = species) +
  geom_point() +
  geom_lm_indicator(formula = y ~ I(x^3) + I(x^2) + x + indicator)
## Warning: Removed 2 rows containing non-finite values (`stat_lmindicator()`).
## Removed 2 rows containing missing values (`geom_point()`).

# Same additive model, faceted by sex.
ggplot(palmerpenguins::penguins) +
  aes(x = flipper_length_mm, y = body_mass_g,
      color = species, indicator = species) +
  geom_point() +
  geom_lm_indicator(formula = y ~ x + indicator) +
  facet_wrap(~sex)
## Warning: Removed 2 rows containing non-finite values (`stat_lmindicator()`).
## Warning: Removed 2 rows containing missing values (`geom_point()`).
Create a residuals layer analogue.
library(tidyverse)
library(gapminder)
# Countries of the Americas in 2002 with their populations.
prep <- gapminder %>%
  filter(continent == "Americas", year == 2002) %>%
  select(country, pop)

# One circle per country, with circle area given by population.
pack <- packcircles::circleProgressiveLayout(prep$pop, sizetype = "area")

cbind(prep, pack) %>%
  mutate(id = row_number()) %>%
  tibble() # for nicer display
## # A tibble: 25 × 6
## country pop x y radius id
## <fct> <int> <dbl> <dbl> <dbl> <int>
## 1 Argentina 38331121 -3493. 0 3493. 1
## 2 Bolivia 8445134 1640. 0 1640. 2
## 3 Brazil 179914212 2733. -9142. 7568. 3
## 4 Canada 31902268 1151. 4801. 3187. 4
## 5 Chile 15497046 5274. 1302. 2221. 5
## 6 Colombia 41008227 10562. -1161. 3613. 6
## 7 Costa Rica 3834934 -4573. -4469. 1105. 7
## 8 Cuba 11226999 -7453. -3647. 1890. 8
## 9 Dominican Republic 8650322 -8637. -300. 1659. 9
## 10 Ecuador 12921234 -7908. 3315. 2028. 10
## # … with 15 more rows
# Turn each circle of the layout into a polygon outline with 50 vertices.
circle_outlines <- packcircles::circleLayoutVertices(pack, npoints = 50)

ggplot(circle_outlines) +
  aes(x = x, y = y, group = id, fill = factor(id)) +
  geom_polygon(colour = "black", alpha = 0.6) +
  # Labels come from the per-country table, so the polygon-level group and
  # fill mappings are switched off for this layer.
  geom_text(
    data = cbind(prep, pack),
    mapping = aes(x, y, size = pop, label = country,
                  group = NULL, fill = NULL)
  ) +
  theme(legend.position = "none") +
  coord_equal()
# Panel-level computation for circle packing.
#
# data:   data frame with one row per circle. An optional `area` column
#         sets the circle sizes; without it every circle gets area 1.
# scales: you won't use the scales argument, but ggplot will later
#
# Returns the circle outline vertices (columns x, y, id) joined to the
# input rows, where `id` is the row number of the circle in `data`.
compute_panel_circle_pack <- function(data, scales) {
  data1 <- data %>%
    mutate(id = row_number())

  # Test the column names exactly; `$` on a data frame partially matches,
  # so a column such as `area_km` would have passed an is.null() check.
  if (!"area" %in% names(data1)) {
    data1 <- data1 %>%
      mutate(area = 1)
  }

  # The vertices supply x and y. Drop same-named input columns and join on
  # `id` only, so the join key never depends on what else `data` carries
  # and no "Joining with ..." message is emitted.
  attributes_by_id <- data1 %>%
    select(-any_of(c("x", "y")))

  data1 %>%
    pull(area) %>%
    packcircles::circleProgressiveLayout(sizetype = "area") %>%
    packcircles::circleLayoutVertices(npoints = 300) %>%
    left_join(attributes_by_id, by = "id")
}
# step 1b test the computation function
gapminder::gapminder %>%
  filter(continent == "Americas", year == 2002) %>%
  # the compute function reads circle sizes from a column named `area`
  rename(area = pop) %>%
  compute_panel_circle_pack() %>%
  head()
## Joining with `by = join_by(id)`
## x y id country continent year lifeExp area gdpPercap
## 1 0.0000000 0.00000 1 Argentina Americas 2002 74.34 38331121 8797.641
## 2 -0.7660766 73.15225 1 Argentina Americas 2002 74.34 38331121 8797.641
## 3 -3.0639703 146.27241 1 Argentina Americas 2002 74.34 38331121 8797.641
## 4 -6.8926731 219.32842 1 Argentina Americas 2002 74.34 38331121 8797.641
## 5 -12.2505058 292.28821 1 Argentina Americas 2002 74.34 38331121 8797.641
## 6 -19.1351181 365.11980 1 Argentina Americas 2002 74.34 38331121 8797.641
# Same computation; inspect the structure of the full result.
gapminder::gapminder %>%
  filter(continent == "Americas", year == 2002) %>%
  # the compute function reads circle sizes from a column named `area`
  rename(area = pop) %>%
  compute_panel_circle_pack() %>%
  str()
## Joining with `by = join_by(id)`
## 'data.frame': 7525 obs. of 9 variables:
## $ x : num 0 -0.766 -3.064 -6.893 -12.251 ...
## $ y : num 0 73.2 146.3 219.3 292.3 ...
## $ id : int 1 1 1 1 1 1 1 1 1 1 ...
## $ country : Factor w/ 142 levels "Afghanistan",..: 5 5 5 5 5 5 5 5 5 5 ...
## $ continent: Factor w/ 5 levels "Africa","Americas",..: 2 2 2 2 2 2 2 2 2 2 ...
## $ year : int 2002 2002 2002 2002 2002 2002 2002 2002 2002 2002 ...
## $ lifeExp : num 74.3 74.3 74.3 74.3 74.3 ...
## $ area : int 38331121 38331121 38331121 38331121 38331121 38331121 38331121 38331121 38331121 38331121 ...
## $ gdpPercap: num 8798 8798 8798 8798 8798 ...
# step 1b test the computation function: draw the computed outlines
gapminder::gapminder %>%
  filter(continent == "Americas", year == 2002) %>%
  # the compute function reads circle sizes from a column named `area`
  rename(area = pop) %>%
  compute_panel_circle_pack() %>%
  ggplot(mapping = aes(x = x, y = y, fill = country)) +
  geom_polygon()
## Joining with `by = join_by(id)`
# my_setup_data <- function(data, params){
# if(data$group[1] == -1){
# nrows <- nrow(data)
# data$group <- seq_len(nrows)
# }
# data
# }
# Stat that lays out the circles once per panel. After the computation,
# rows are grouped by the computed `id`.
StatCirclepack <- ggplot2::ggproto(
  `_class` = "StatCirclepack",
  `_inherit` = ggplot2::Stat,
  compute_panel = compute_panel_circle_pack,
  required_aes = c("id"),
  # setup_data = my_setup_data,
  default_aes = ggplot2::aes(group = ggplot2::after_stat(id))
)
# Layer function pairing StatCirclepack with a polygon geom.
# Arguments passed through `...` are forwarded to the layer as parameters
# together with `na.rm`.
geom_polygon_circlepack <- function(mapping = NULL, data = NULL,
                                    position = "identity", na.rm = FALSE,
                                    show.legend = NA,
                                    inherit.aes = TRUE, ...) {
  layer_params <- list(na.rm = na.rm, ...)

  ggplot2::layer(
    data = data,
    mapping = mapping,
    stat = StatCirclepack,        # proto object from Step 2
    geom = ggplot2::GeomPolygon,  # inherit other behavior
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = layer_params
  )
}
# One equally sized circle per country (no `area` aesthetic mapped yet).
# Outline width is set with `linewidth`; `size` for lines is deprecated
# since ggplot2 3.4.0.
gapminder::gapminder %>%
  filter(year == 2002) %>%
  ggplot() +
  aes(id = country) +
  geom_polygon_circlepack(alpha = .5, linewidth = .002)
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## Joining with `by = join_by(id)`
# Each statement below amends the most recently created plot, so they
# have to be run in this order.

# Fill the circles by continent.
last_plot() +
aes(fill = continent)
## Joining with `by = join_by(id)`
# Map population to the `area` aesthetic read by the circle-pack stat.
last_plot() +
aes(area = pop)
## Joining with `by = join_by(id)`
# Outline color by continent and one facet per continent; the stat uses
# compute_panel, so the layout is computed separately for each facet.
last_plot() +
aes(color = continent) +
facet_wrap(facets = vars(continent))
## Joining with `by = join_by(id)`
## Joining with `by = join_by(id)`
## Joining with `by = join_by(id)`
## Joining with `by = join_by(id)`
## Joining with `by = join_by(id)`
Notes:
library(tidyverse)
library(sf)
## Linking to GEOS 3.10.2, GDAL 3.4.2, PROJ 8.2.1; sf_use_s2() is TRUE
# County geometries; the identifier column is renamed from GEOID to FIPS
# to match the census table read below.
fips_geometries <-
  "https://wilkelab.org/SDS375/datasets/US_counties.rds" %>%
  url() %>%
  readRDS() %>%
  rename(FIPS = GEOID)

# Census table; FIPS is read as character.
US_census <- read_csv(
  file = "https://wilkelab.org/SDS375/datasets/US_census.csv",
  col_types = cols(FIPS = "c")
)
# from Claus Wilke on ggplot2
# The usual approach: join the data to the geometries first, then plot.
classic_plot_sf_layer <- fips_geometries %>%
  left_join(US_census, by = "FIPS") %>%
  ggplot() +
  geom_sf(aes(fill = mean_work_travel), linewidth = .1) +
  scale_fill_viridis_c(option = "magma")

classic_plot_sf_layer
# Reference table for the county stat: the geometry and bounding-box
# columns of the classic plot's layer data, keyed by a `fips` column.
county_fips <- tibble(FIPS = fips_geometries$FIPS)

reference <- classic_plot_sf_layer %>%
  layer_data() %>%
  select(geometry, xmin, xmax, ymin, ymax) %>%
  bind_cols(county_fips) %>%
  rename(fips = FIPS)
# Panel-level computation for county maps: attach geometries to the data.
#
# data:   data frame with a `fips` column
# scales: not used here; part of the compute_panel() signature
#
# Returns the rows of `data` whose `fips` appears in the global `reference`
# table, with that table's columns added and `group` set to -1.
compute_panel_county <- function(data, scales) {
  data %>%
    # An explicit key keeps the join independent of any other columns the
    # two tables share and avoids the "Joining with ..." message.
    inner_join(reference, by = "fips", multiple = "all") %>%
    mutate(group = -1)
}
# Stat that looks up county geometries once per panel and maps the joined
# `geometry` column to the geometry aesthetic.
StatCounty <- ggplot2::ggproto(
  `_class` = "StatCounty",
  `_inherit` = ggplot2::Stat,
  compute_panel = compute_panel_county,
  # required_aes = c("fips"),
  # NOTE(review): left disabled as in the original; StatState declares its
  # key aesthetic as required - confirm whether this one should too.
  # setup_data = my_setup_data,
  default_aes = ggplot2::aes(geometry = ggplot2::after_stat(geometry))
)
# Layer function pairing StatCounty with the sf geom.
# Arguments passed through `...` are forwarded to the layer as parameters
# together with `na.rm`.
geom_sf_county <- function(mapping = NULL,
                           data = NULL,
                           position = "identity",
                           na.rm = FALSE,
                           show.legend = NA,
                           inherit.aes = TRUE, ...) {
  layer_params <- list(na.rm = na.rm, ...)

  ggplot2::layer(
    data = data,
    mapping = mapping,
    stat = StatCounty,       # proto object from step 2
    geom = ggplot2::GeomSf,  # inherit other behavior
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = layer_params
  )
}
# The flat census table goes straight into ggplot; the layer's stat joins
# the geometries via the `fips` aesthetic.
read_csv(
  file = "https://wilkelab.org/SDS375/datasets/US_census.csv",
  col_types = cols(FIPS = "c")
) %>%
  ggplot() +
  aes(fips = FIPS, fill = mean_work_travel) +
  geom_sf_county(linewidth = .02, color = "darkgrey") +
  coord_sf() +
  scale_fill_viridis_c(option = "magma")
## Joining with `by = join_by(fips)`
We give you step 0
library(sf)
# Flat table of the 50 state names with a running row number.
my_states_df <- tibble(state.name, row = seq_along(state.name))

# State boundaries from the Census cartographic boundary files:
# https://www.census.gov/geographies/mapping-files/time-series/geo/carto-boundary-file.html
# Puerto Rico and the District of Columbia are excluded. The original
# filter spelled it "Colombia", so it never matched and DC stayed in.
states_geometries <- read_sf("cb_2018_us_state_20m/cb_2018_us_state_20m.shp") %>%
  filter(
    NAME != "Puerto Rico",
    NAME != "District of Columbia"
  ) %>%
  select(STUSPS, NAME, geometry)
# Step 0: join the geometries onto the state table and plot them.
# `state` is not an aesthetic geom_sf knows (hence the warning below); it
# is mapped so the state name travels into the layer data used in Step 1.
my_states_df %>%
  rename(NAME = state.name) %>%
  left_join(states_geometries, by = "NAME") %>%
  ggplot() +
  # Map `state` to the NAME column. After the rename there is no
  # `state.name` column, so the old mapping silently picked up the global
  # datasets::state.name vector and relied on matching row order.
  geom_sf(aes(geometry = geometry, state = NAME)) +
  coord_sf()
## Joining with `by = join_by(NAME)`
## Warning in layer_sf(geom = GeomSf, data = data, mapping = mapping, stat = stat,
## : Ignoring unknown aesthetics: state
Step 1.
# Reference table for the state stat, taken from the layer data of the
# plot created just before this statement: state name, geometry and
# bounding-box columns.
states_reference <- last_plot() %>%
  layer_data() %>%
  select(state, geometry, xmin, xmax, ymin, ymax)
# Panel-level computation for state maps: attach geometries to the data.
#
# data:   data frame with a `state` column
# scales: not used here; part of the compute_panel() signature
#
# Returns the rows of `data` whose `state` appears in the global
# `states_reference` table, with that table's columns added and `group`
# set to -1.
compute_panel_states <- function(data, scales) {
  data %>%
    # An explicit key keeps the join independent of any other columns the
    # two tables share and avoids the "Joining with ..." message.
    inner_join(states_reference, by = "state", multiple = "all") %>%
    mutate(group = -1)
}
## Step 2.
# Stat that looks up state geometries once per panel and maps the joined
# `geometry` column to the geometry aesthetic.
StatState <- ggplot2::ggproto(
  `_class` = "StatState",
  `_inherit` = ggplot2::Stat,
  compute_panel = compute_panel_states,
  required_aes = c("state"),
  # setup_data = my_setup_data,
  default_aes = ggplot2::aes(geometry = ggplot2::after_stat(geometry))
)
## Step 3.
# Layer function pairing StatState with the sf geom.
# Arguments passed through `...` are forwarded to the layer as parameters
# together with `na.rm`.
geom_sf_state <- function(mapping = NULL,
                          data = NULL,
                          position = "identity",
                          na.rm = FALSE,
                          show.legend = NA,
                          inherit.aes = TRUE, ...) {
  layer_params <- list(na.rm = na.rm, ...)

  ggplot2::layer(
    data = data,
    mapping = mapping,
    stat = StatState,        # proto object from step 2
    geom = ggplot2::GeomSf,  # inherit other behavior
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = layer_params
  )
}
# Use the new layer: only a flat table with state names is needed.
my_states_df %>%
  mutate(state.name = as.factor(state.name)) %>%
  ggplot() +
  # Plot components are combined with `+`. The original `%>%` here only
  # worked because it passed the aes() in as the layer's `mapping`.
  aes(state = state.name) +
  geom_sf_state() +
  aes(fill = row) +
  coord_sf()
## Joining with `by = join_by(state)`
`