# Status Quo
library(tidyverse)
## Warning: package 'ggplot2' was built under R version 4.4.1
#' Split a data frame into reproducible train/test halves
#'
#' Seeds the RNG, draws a random half of the row indices as the training set,
#' and returns either those rows or the remaining ones.
#'
#' @param data A data frame to split. Defaults to `cars`.
#' @param scales Not used in the body. Presumably present so the function can
#'   be called the way ggplot2 stat compute functions are called -- confirm.
#' @param seed Value passed to `set.seed()`; read from option `"seed"`,
#'   falling back to 1999. Note that this resets the global RNG state.
#' @param step `"train"` returns the training rows; any other value returns
#'   the complementary (test) rows. Read from option `"tt_step"`, falling back
#'   to `"train"`.
#' @return The selected rows of `data`, in their original order, always as a
#'   data frame (also when `data` has a single column).
compute_train_test <- function(data = cars,
                               scales,
                               seed = getOption("seed", 1999),
                               step = getOption("tt_step", "train")) {
  set.seed(seed)
  n <- nrow(data)

  # The earlier form `n/2 |> round()` parsed as `n / round(2)` because `|>`
  # binds tighter than `/`, so the size was never rounded. Round explicitly.
  n_train <- round(n / 2)

  # sample.int() and seq_len() are safe for n == 0, unlike 1:n.
  train_idx <- sample.int(n, size = n_train, replace = FALSE)
  is_train <- seq_len(n) %in% train_idx

  keep <- if (step == "train") is_train else !is_train

  # drop = FALSE keeps a one-column data frame from collapsing to a vector.
  data[keep, , drop = FALSE]
}
#' Fit `y ~ x` on the training split and predict `y` for a chosen split
#'
#' @param data A data frame with columns `x` and `y`.
#' @param scales Not used in the body.
#' @param step Which split to return predictions for, forwarded to
#'   `compute_train_test()`. Read from option `"tt_step"`, falling back to
#'   `"train"`.
#' @return The rows of the requested split with `y` replaced by the linear
#'   model's predictions.
compute_train_test_lm <- function(data, scales, step = getOption("tt_step", "train")) {
  # The model is always fitted on the training rows, whatever `step` is.
  fit_rows <- compute_train_test(data, step = "train")
  out_rows <- compute_train_test(data, step = step)

  fit <- lm(y ~ x, data = fit_rows)

  # Overwrite the observed y with the fitted values for the returned rows.
  out_rows$y <- predict(fit, out_rows)
  out_rows
}
# Rename to the x/y columns that compute_train_test_lm() models, keep the
# training half, then fit and predict. compute_train_test_lm() splits the rows
# it receives again, so the result covers only part of that half.
compute_train_test_lm(
  compute_train_test(
    rename(cars, x = speed, y = dist)
  )
)
## x y
## 10 11 19.73009
## 15 12 24.27865
## 22 14 33.37576
## 24 15 37.92432
## 29 17 47.02143
## 30 17 47.02143
## 34 18 51.56999
## 37 19 56.11855
## 38 19 56.11855
## 46 24 78.86134
## 48 24 78.86134
## 50 25 83.40990
library(statexpress)
# train
# No tt_step option has been set above, so both stats use their "train"
# default: raw training points plus the fitted values in blue.
ggplot(cars, aes(speed, dist)) +
  geom_point(stat = qstat(compute_train_test)) +
  geom_point(
    stat = qstat(compute_train_test_lm),
    color = "blue"
  )

# test
# Switch the split through a global option; it stays set for everything that
# runs afterwards until it is changed again.
options(tt_step = "test")
ggplot(cars, aes(speed, dist)) +
  geom_point(stat = qstat(compute_train_test)) +
  geom_point(
    stat = qstat(compute_train_test_lm),
    color = "green"
  )

# train
# The plot built while tt_step = "train" is stored as `train` (it was
# previously stored under the name `test`, contradicting this section).
options(tt_step = "train")
train <- ggplot(mpg) +
  aes(cty, hwy) +
  geom_point(stat = qstat(compute_train_test)) +
  geom_point(stat = qstat(compute_train_test_lm), color = "blue")
# test
options(tt_step = "test")
test <- ggplot(mpg) +
  aes(cty, hwy) +
  geom_point(stat = qstat(compute_train_test)) +
  geom_point(
    stat = qstat(compute_train_test_lm),
    color = "green"
  )
## the lesson of why not to use options! 😬
## NOTE(review): the compute functions read getOption("tt_step") through a
## default argument, i.e. when they are called. Presumably that happens when
## the plots are rendered below, by which point tt_step is "test" for both
## panels -- confirm against ggplot2/statexpress behaviour.
library(patchwork)
## Warning: package 'patchwork' was built under R version 4.4.1
# Same left-to-right panel order as before: blue-layer plot, then green-layer.
train + test
