# Intro Thoughts

# Status Quo

library(tidyverse)
## Warning: package 'ggplot2' was built under R version 4.4.1
# Randomly assign each row of a panel to a "train" (~70%) or "test" split.
#
# Intended as a ggproto `compute_panel`: the RNG is reseeded with `seed` on
# every call, so every layer that uses it draws the same split. The caller's
# global RNG state is restored on exit, so drawing a plot does not reset the
# session's random stream as a side effect.
#
# data:   data frame of panel data (any columns).
# scales: panel scales supplied by ggplot; unused.
# seed:   RNG seed for the split; defaults to getOption("seed", 1999).
# step:   which split to flag in `ind_step` ("train" or "test"); defaults to
#         getOption("tt_step", "train").
#
# Returns `data` with two added columns: `train_test` ("train"/"test") and
# `ind_step` (TRUE where `train_test == step`).
compute_split <- function(data, scales, seed = getOption("seed", 1999),
                          step = getOption("tt_step", "train")) {
  # Snapshot the global RNG state (it may not exist yet) and put it back on
  # exit so the set.seed() below has no lasting effect on the caller.
  env <- globalenv()
  had_seed <- exists(".Random.seed", envir = env, inherits = FALSE)
  if (had_seed) {
    old_seed <- get(".Random.seed", envir = env, inherits = FALSE)
  }
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = env)
    } else if (exists(".Random.seed", envir = env, inherits = FALSE)) {
      rm(".Random.seed", envir = env)
    }
  }, add = TRUE)

  set.seed(seed)
  n <- nrow(data)
  n_train <- round(n * 0.7)
  # sample.int(n, k) yields the same indices as sample(1:n, k) without
  # relying on sample()'s special case for length-one input.
  train_rows <- sample.int(n, size = n_train, replace = FALSE)

  in_train <- seq_len(n) %in% train_rows
  # Indexing (rather than ifelse()) keeps the column character even for a
  # zero-row panel.
  data$train_test <- c("test", "train")[in_train + 1L]
  data$ind_step <- data$train_test == step
  data
}

# Stat that labels every row "train"/"test" via compute_split(); by default
# alpha is mapped to `ind_step`, i.e. whether the row is in the chosen step.
StatSplit <- ggproto(
  `_class` = "StatSplit",
  `_inherit` = Stat,
  default_aes = aes(alpha = after_stat(ind_step)),
  compute_panel = compute_split
)

# Preview the split: rename `cars` to the x/y columns a stat receives first.
cars |>
  rename(y = dist, x = speed) |>
  compute_split() |>
  head(n = 6L)
##   x  y train_test ind_step
## 1 4  2       test    FALSE
## 2 4 10      train     TRUE
## 3 7  4       test    FALSE
## 4 7 22      train     TRUE
## 5 8 16       test    FALSE
## 6 9 10       test    FALSE
# Fit y ~ x on the training rows of a train/test split and return every row
# with its fitted value, ready for point/line/segment geoms.
#
# data:   panel data with numeric `x` and `y` columns.
# scales: panel scales supplied by ggplot; unused.
# step:   which split ("train" or "test") is flagged TRUE in `ind_step`.
# seed:   RNG seed forwarded to compute_split(); the default matches
#         compute_split()'s, so all layers of a plot share one split.
#
# Returns `data` with `y` replaced by the fitted value, the observed point
# kept in `xend`/`yend`, and the `train_test`/`ind_step` columns added by
# compute_split().
compute_split_lm <- function(data, scales,
                             step = getOption("tt_step", "train"),
                             seed = getOption("seed", 1999)) {
  # Fail loudly rather than let lm() resolve a missing `x`/`y` from the
  # calling environment.
  missing_cols <- setdiff(c("x", "y"), names(data))
  if (length(missing_cols) > 0) {
    stop("compute_split_lm() needs column(s): ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  data <- compute_split(data, seed = seed, step = step)

  # Keep the observation so geom_segment() can draw the residual from the
  # fitted value (x, y) to the observed point (xend, yend).
  data$yend <- data$y
  data$xend <- data$x

  # The model is always fit on the training rows, whichever step is shown.
  train_rows <- data[data$train_test == "train", , drop = FALSE]
  model <- lm(y ~ x, data = train_rows)

  data$y <- predict(model, newdata = data)
  data
}


# Stat that fits y ~ x on the training rows (compute_split_lm) and emits the
# fitted values, with the observed point in xend/yend. Alpha defaults to
# `ind_step`, i.e. whether a row belongs to the requested step.
# required_aes makes ggplot reject layers without x/y up front instead of
# letting lm() fail (or pick up stray globals) inside the stat.
StatLmSplit <- ggproto("StatLmSplit", Stat,
                       required_aes = c("x", "y"),
                       compute_panel = compute_split_lm,
                       default_aes = aes(alpha = after_stat(ind_step)))

# Preview the fitted values: `y` is now the prediction, `yend` the observation.
cars |>
  rename(y = dist, x = speed) |>
  compute_split_lm() |>
  head(n = 6L)
##   x         y train_test ind_step yend xend
## 1 4 -2.961164       test    FALSE    2    4
## 2 4 -2.961164      train     TRUE   10    4
## 3 7  9.444378       test    FALSE    4    7
## 4 7  9.444378      train     TRUE   22    7
## 5 8 13.579559       test    FALSE   16    8
## 6 9 17.714740       test    FALSE   10    9
# Training step: observed points (training rows emphasised by alpha), the
# fitted line, the fitted values, and dashed residual segments.
# (The line has a fixed alpha, so its `step` does not affect the drawing.)
p_train <- ggplot(cars) +
  aes(speed, dist) +
  geom_point(stat = StatSplit) + # geom_point_train
  geom_line(stat = StatLmSplit, alpha = 1) +
  geom_point(stat = StatLmSplit) +
  geom_segment(stat = StatLmSplit, linetype = "dashed") +
  labs(title = "training step")

# Testing step: the same fitted line, with test rows emphasised instead.
p_test <- ggplot(cars) +
  aes(speed, dist) +
  geom_point(stat = StatSplit, step = "test") + # geom_point_test
  geom_line(stat = StatLmSplit, step = "test", alpha = 1) +
  geom_segment(stat = StatLmSplit, step = "test", linetype = "dashed") +
  labs(title = "testing step")


library(patchwork)
## Warning: package 'patchwork' was built under R version 4.4.1
p_train + p_test
## Warning: Using alpha for a discrete variable is not advised.
## Warning: Using alpha for a discrete variable is not advised.

# Turn each point's residual into a square for geom_rect(): centre at
# (residual / 2, residual / 2) with side `residual`, i.e. one corner on the
# origin, so the square's area is that point's squared error.
#
# data:   panel data with `x` and `y` columns (passed to compute_split_lm()).
# scales: panel scales supplied by ggplot; unused.
# step:   which split ("train" or "test") is flagged TRUE in `ind_step`.
#
# Returns one row per input row with columns train_test, residual, x, y,
# width, height and ind_step.
compute_split_lm_se <- function(data, scales,
                                step = getOption("tt_step", "train")) {
  data <- compute_split_lm(data, step = step)

  # compute_split_lm() stores the observation in `yend` and the fit in `y`.
  data$residual <- data$yend - data$y

  data$x <- data$residual / 2
  data$y <- data$residual / 2
  data$width <- data$residual
  data$height <- data$residual

  keep <- c("train_test", "residual", "x", "y", "width", "height", "ind_step")
  data[, keep, drop = FALSE]
}


# Preview the residual squares: x/y are the square's centre, width/height its side.
cars |>
  rename(y = dist, x = speed) |>
  compute_split_lm_se() |>
  head(n = 6L)
##   train_test  residual         x         y     width    height ind_step
## 1       test  4.961164  2.480582  2.480582  4.961164  4.961164    FALSE
## 2      train 12.961164  6.480582  6.480582 12.961164 12.961164     TRUE
## 3       test -5.444378 -2.722189 -2.722189 -5.444378 -5.444378    FALSE
## 4      train 12.555622  6.277811  6.277811 12.555622 12.555622     TRUE
## 5       test  2.420441  1.210221  1.210221  2.420441  2.420441    FALSE
## 6       test -7.714740 -3.857370 -3.857370 -7.714740 -7.714740    FALSE
# Stat that draws each point's squared error as a square (compute_split_lm_se).
# Alpha defaults to `ind_step`, i.e. whether a row belongs to the requested
# step. required_aes makes ggplot reject layers without x/y up front, since
# the underlying lm() fit needs both.
StatLmSplitSE <- ggproto("StatLmSplitSE", Stat,
                         required_aes = c("x", "y"),
                         compute_panel = compute_split_lm_se,
                         default_aes = aes(alpha = after_stat(ind_step)))

# Squared errors as squares for the training step (left) and testing step
# (right). Both panels use dashed black outlines; without a colour, the
# dashed linetype on the training panel would never be drawn.
squared_error_train <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(stat = StatLmSplitSE, color = "black", linetype = "dashed") +
  coord_equal() +
  scale_alpha_discrete(range = c(0, .6))
## Warning: Using alpha for a discrete variable is not advised.
squared_error_test <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(stat = StatLmSplitSE, step = "test", color = "black",
            linetype = "dashed") +
  coord_equal() +
  scale_alpha_discrete(range = c(0, .6))
## Warning: Using alpha for a discrete variable is not advised.
squared_error_train + squared_error_test

# Experiment

# Summarise the residuals of each split (train/test) as its RMSE, and lay it
# out as a square of side `rmse` with one corner on the origin, so the
# square's area is the split's mean squared error.
#
# data:   panel data with `x` and `y` columns (passed down to compute_split_lm).
# scales: panel scales supplied by ggplot; unused.
# step:   which split ("train" or "test") is flagged TRUE in `ind_step`.
#
# Returns an ungrouped tibble with one row per split and columns
# train_test, rmse, x, y, width, height and ind_step.
compute_split_lm_rmse <- function(data, scales,
                                  step = getOption("tt_step", "train")) {
  # .groups = "drop" returns an ungrouped result and silences summarise()'s
  # regrouping message, which otherwise fires on every plot build.
  per_split <- compute_split_lm_se(data, step = step) |>
    group_by(ind_step, train_test) |>
    summarise(rmse = sqrt(mean(residual^2)), .groups = "drop")

  per_split |>
    mutate(
      x = rmse / 2,
      y = rmse / 2,
      width = rmse,
      height = rmse
    ) |>
    select(train_test, rmse, x, y, width, height, ind_step)
}

# Preview the per-split RMSE squares.
cars |>
  rename(y = dist, x = speed) |>
  compute_split_lm_rmse() |>
  head(n = 6L)
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.
## # A tibble: 2 × 7
## # Groups:   ind_step [2]
##   train_test  rmse     x     y width height ind_step
##   <chr>      <dbl> <dbl> <dbl> <dbl>  <dbl> <lgl>   
## 1 test        9.58  4.79  4.79  9.58   9.58 FALSE   
## 2 train      17.0   8.50  8.50 17.0   17.0  TRUE
# Stat that draws one square per split whose side is that split's RMSE
# (compute_split_lm_rmse). Alpha defaults to `ind_step`, i.e. whether the
# split is the requested step. required_aes makes ggplot reject layers
# without x/y up front, since the underlying lm() fit needs both.
StatLmSplitRMSE <- ggproto("StatLmSplitRMSE", Stat,
                           required_aes = c("x", "y"),
                           compute_panel = compute_split_lm_rmse,
                           default_aes = aes(alpha = after_stat(ind_step)))


# Training step (left): dashed per-point squared-error squares overlaid with
# one blue square per split whose side is that split's RMSE.
squared_error_train <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(stat = StatLmSplitSE, linetype = "dashed", color = "black") +
  geom_rect(stat = StatLmSplitRMSE, fill = "blue") +
  coord_equal() +
  scale_alpha(range = c(0, .6))

# Testing step (right): the same layers with the test split emphasised.
squared_error_test <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(
    stat = StatLmSplitSE, step = "test",
    linetype = "dashed", color = "black"
  ) +
  geom_rect(stat = StatLmSplitRMSE, step = "test", fill = "blue") +
  coord_equal() +
  scale_alpha(range = c(0, .6))

squared_error_train + squared_error_test
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.

# Inspect the computed RMSE layer (layer 2) of the testing-step plot. The
# plot is named explicitly rather than relying on last_plot(), which is
# ambiguous after a patchwork composition.
layer_data(squared_error_test, i = 2)
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.
##   PANEL        x        y train_test      rmse     width    height ind_step
## 1     1 8.495426 8.495426      train 16.990851 16.990851 16.990851    FALSE
## 2     1 4.791872 4.791872       test  9.583744  9.583744  9.583744     TRUE
##   alpha xmin      xmax ymin      ymax colour fill linewidth linetype
## 1   0.0    0 16.990851    0 16.990851     NA blue       0.5        1
## 2   0.6    0  9.583744    0  9.583744     NA blue       0.5        1

# Closing Remarks, Other Relevant Work, and Caveats