Status Quo
library(tidyverse)
## Warning: package 'ggplot2' was built under R version 4.4.1
# Label each row of `data` as "train" or "test" using a seeded 70/30 split.
#
# Args:
#   data:   data frame to split (one row per observation).
#   scales: unused; present so the function can serve as a ggproto
#           `compute_panel` method.
#   seed:   RNG seed; defaults to getOption("seed", 1999).
#   step:   which split ("train" or "test") `ind_step` should flag;
#           defaults to getOption("tt_step", "train").
#
# Returns `data` with two added columns: `train_test` (character) and
# `ind_step` (logical, TRUE where `train_test == step`).
#
# Note: set.seed() resets the session's global RNG state on every call.
compute_split <- function(data, scales,
                          seed = getOption("seed", 1999),
                          step = getOption("tt_step", "train")) {
  set.seed(seed)
  n <- nrow(data)
  n_train <- round(n * 0.7)
  # sample.int() draws the same indices as sample(1:n, ...) but is also
  # well-defined for a zero-row input.
  train_rows <- sample.int(n, size = n_train, replace = FALSE)
  split_label <- rep("test", n)
  split_label[train_rows] <- "train"
  # Plain column assignment (not mutate()) so that a column of `data` that
  # happens to be called `step` cannot shadow the `step` argument.
  data$train_test <- split_label
  data$ind_step <- split_label == step
  data
}
# Stat that tags rows as train/test and, by default, maps alpha to the
# logical `ind_step` column computed by compute_split().
StatSplit <- ggproto(
  "StatSplit",
  Stat,
  default_aes = aes(alpha = after_stat(ind_step)),
  compute_panel = compute_split
)
# Preview the split labels on `cars`, with columns renamed to x / y.
head(
  compute_split(
    rename(cars, x = speed, y = dist)
  )
)
## x y train_test ind_step
## 1 4 2 test FALSE
## 2 4 10 train TRUE
## 3 7 4 test FALSE
## 4 7 22 train TRUE
## 5 8 16 test FALSE
## 6 9 10 test FALSE
# Split the rows with compute_split(), fit lm(y ~ x) on the training rows
# only, and replace `y` with the fitted value for every row.
#
# Args:
#   data:   data frame with numeric `x` and `y` columns.
#   scales: unused; present for the ggproto `compute_panel` signature.
#   step:   forwarded to compute_split() to define `ind_step`.
#
# Returns the split data with `y` holding model predictions and the
# observed values preserved in `xend` / `yend`.
compute_split_lm <- function(data, scales,
                             step = getOption("tt_step", "train")) {
  split_data <- compute_split(data, step = step)

  # Keep the observed point so a segment can join it to the fitted value.
  split_data$yend <- split_data$y
  split_data$xend <- split_data$x

  train_rows <- filter(split_data, train_test == "train")
  fit <- lm(y ~ x, data = train_rows)

  split_data$y <- predict(fit, split_data)
  split_data
}
# Stat that returns training-fit linear-model predictions for every row;
# alpha defaults to the logical `ind_step` column.
StatLmSplit <- ggproto(
  "StatLmSplit",
  Stat,
  default_aes = aes(alpha = after_stat(ind_step)),
  compute_panel = compute_split_lm
)
# Preview fitted values (y) next to the observed values (xend / yend).
head(
  compute_split_lm(
    rename(cars, x = speed, y = dist)
  )
)
## x y train_test ind_step yend xend
## 1 4 -2.961164 test FALSE 2 4
## 2 4 -2.961164 train TRUE 10 4
## 3 7 9.444378 test FALSE 4 7
## 4 7 9.444378 train TRUE 22 7
## 5 8 13.579559 test FALSE 16 8
## 6 9 17.714740 test FALSE 10 9
# Training step: the stats use their default step ("train" unless the
# `tt_step` option is set), so training rows are the ones flagged by
# `ind_step`. The fitted line is drawn fully opaque (alpha = 1).
# The object names now match the plot titles; previously the
# "training step" plot was stored in `p_test` and vice versa.
p_train <- ggplot(cars) +
  aes(speed, dist) +
  geom_point(stat = StatSplit) + # geom_point_train
  geom_line(stat = StatLmSplit, step = "test", alpha = 1) +
  geom_point(stat = StatLmSplit) +
  geom_segment(stat = StatLmSplit, lty = "dashed") +
  labs(title = "training step")

# Testing step: every split-aware layer flags the test rows instead.
p_test <- ggplot(cars) +
  aes(speed, dist) +
  geom_point(stat = StatSplit, step = "test") + # geom_point_test
  geom_line(stat = StatLmSplit, step = "test", alpha = 1) +
  geom_segment(stat = StatLmSplit, step = "test", lty = "dashed") +
  labs(title = "testing step")

library(patchwork)
## Warning: package 'patchwork' was built under R version 4.4.1
# Training plot on the left, testing plot on the right (same layout as before).
p_train + p_test
## Warning: Using alpha for a discrete variable is not advised.
## Warning: Using alpha for a discrete variable is not advised.

# Turn each row's residual into the centre/size description of a square.
#
# Args:
#   data:   data frame with numeric `x` and `y` columns.
#   scales: unused; present for the ggproto `compute_panel` signature.
#   step:   forwarded to compute_split_lm() to define `ind_step`.
#
# Returns one row per observation with columns `train_test`, `residual`,
# `x`, `y`, `width`, `height` and `ind_step`. `residual` is observed minus
# fitted; `x`/`y` are residual / 2 and `width`/`height` are the residual,
# so read as centre and size the square has one corner at the origin.
compute_split_lm_se <- function(data, scales,
                                step = getOption("tt_step", "train")) {
  fitted <- compute_split_lm(data, step = step)

  # After compute_split_lm(), `yend` is the observed y and `y` the fit.
  fitted$residual <- fitted$yend - fitted$y
  fitted$x <- fitted$residual / 2
  fitted$y <- fitted$residual / 2
  fitted$width <- fitted$residual
  fitted$height <- fitted$residual

  fitted |>
    select(train_test, residual, x, y, width, height, ind_step)
}
# Preview the per-row residual squares computed from `cars`.
head(
  compute_split_lm_se(
    rename(cars, x = speed, y = dist)
  )
)
## train_test residual x y width height ind_step
## 1 test 4.961164 2.480582 2.480582 4.961164 4.961164 FALSE
## 2 train 12.961164 6.480582 6.480582 12.961164 12.961164 TRUE
## 3 test -5.444378 -2.722189 -2.722189 -5.444378 -5.444378 FALSE
## 4 train 12.555622 6.277811 6.277811 12.555622 12.555622 TRUE
## 5 test 2.420441 1.210221 1.210221 2.420441 2.420441 FALSE
## 6 test -7.714740 -3.857370 -3.857370 -7.714740 -7.714740 FALSE
# Stat that returns one residual square per row; alpha defaults to the
# logical `ind_step` column.
StatLmSplitSE <- ggproto(
  "StatLmSplitSE",
  Stat,
  default_aes = aes(alpha = after_stat(ind_step)),
  compute_panel = compute_split_lm_se
)
# Residual squares with the stat's default step.
squared_error_train <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(stat = StatLmSplitSE, linetype = "dashed") +
  coord_equal() +
  scale_alpha_discrete(range = c(0, .6))
## Warning: Using alpha for a discrete variable is not advised.
# Residual squares with step = "test", outlined in black.
squared_error_test <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(
    stat = StatLmSplitSE, step = "test",
    color = "black", linetype = "dashed"
  ) +
  coord_equal() +
  scale_alpha_discrete(range = c(0, .6))
## Warning: Using alpha for a discrete variable is not advised.
squared_error_train + squared_error_test

Experiment
# Summarise the residuals into one RMSE square per train/test split.
#
# Args:
#   data:   data frame with numeric `x` and `y` columns.
#   scales: unused; present for the ggproto `compute_panel` signature.
#   step:   forwarded to compute_split_lm_se() to define `ind_step`.
#
# Returns an ungrouped tibble with one row per (`ind_step`, `train_test`)
# combination and columns `train_test`, `rmse`, `x`, `y`, `width`,
# `height`, `ind_step`. `x`/`y` are rmse / 2 and `width`/`height` are the
# rmse, mirroring the layout used for the per-row residual squares.
compute_split_lm_rmse <- function(data, scales,
                                  step = getOption("tt_step", "train")) {
  rmse_by_split <- data |>
    compute_split_lm_se(step = step) |>
    group_by(ind_step, train_test) |>
    # .groups = "drop" returns an ungrouped result and silences the
    # "summarise() has grouped output" message on every render.
    summarise(rmse = sqrt(mean(residual^2)), .groups = "drop")

  rmse_by_split$x <- rmse_by_split$rmse / 2
  rmse_by_split$y <- rmse_by_split$rmse / 2
  rmse_by_split$width <- rmse_by_split$rmse
  rmse_by_split$height <- rmse_by_split$rmse

  rmse_by_split |>
    select(train_test, rmse, x, y, width, height, ind_step)
}
# Preview the per-split RMSE squares computed from `cars`.
head(
  compute_split_lm_rmse(
    rename(cars, x = speed, y = dist)
  )
)
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.
## # A tibble: 2 × 7
## # Groups: ind_step [2]
## train_test rmse x y width height ind_step
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <lgl>
## 1 test 9.58 4.79 4.79 9.58 9.58 FALSE
## 2 train 17.0 8.50 8.50 17.0 17.0 TRUE
# Stat that returns one RMSE square per train/test split; alpha defaults
# to the logical `ind_step` column.
StatLmSplitRMSE <- ggproto(
  "StatLmSplitRMSE",
  Stat,
  default_aes = aes(alpha = after_stat(ind_step)),
  compute_panel = compute_split_lm_rmse
)
# Per-row residual squares (dashed outline) with the per-split RMSE square
# (blue fill) layered on top, using the stats' default step.
squared_error_train <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(stat = StatLmSplitSE, linetype = "dashed", color = "black") +
  geom_rect(stat = StatLmSplitRMSE, fill = "blue") +
  coord_equal() +
  scale_alpha(range = c(0, .6))
# Same layers with step = "test".
squared_error_test <- ggplot(cars) +
  aes(speed, dist) +
  geom_rect(
    stat = StatLmSplitSE, step = "test",
    linetype = "dashed", color = "black"
  ) +
  geom_rect(stat = StatLmSplitRMSE, step = "test", fill = "blue") +
  coord_equal() +
  scale_alpha(range = c(0, .6))
squared_error_train + squared_error_test
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.

# Inspect the computed data behind the second layer (i = 2). No plot is
# passed, so layer_data() falls back to its default `plot` argument.
layer_data(i = 2)
## `summarise()` has grouped output by 'ind_step'. You can override using the
## `.groups` argument.
## PANEL x y train_test rmse width height ind_step
## 1 1 8.495426 8.495426 train 16.990851 16.990851 16.990851 FALSE
## 2 1 4.791872 4.791872 test 9.583744 9.583744 9.583744 TRUE
## alpha xmin xmax ymin ymax colour fill linewidth linetype
## 1 0.0 0 16.990851 0 16.990851 NA blue 0.5 1
## 2 0.6 0 9.583744 0 9.583744 NA blue 0.5 1