library(tidyverse)

# Pack several columns into one list column: element i holds the values of
# all supplied columns for row i, and the column names are kept in the
# "varnames" attribute so vars_unpack() can restore them.
# Arguments must be bare column names (ensyms() rejects other expressions).
# cbind() coerces everything to a common type, so factors are stored as
# their integer codes.
vars_pack <- function(...) {
  labels <- as.character(ensyms(...))

  # One matrix with a column per variable, then split it row by row.
  by_row <- asplit(do.call(cbind, list(...)), MARGIN = 1)

  attr(by_row, "varnames") <- labels
  by_row
}

# Inverse of vars_pack(): stack a list of per-row vectors back into a data
# frame with one column per packed variable, named from the "varnames"
# attribute.
vars_unpack <- function(x) {
  labels <- attr(x, "varnames")
  stacked <- do.call(rbind, x)
  colnames(stacked) <- labels
  as.data.frame(stacked)
}


# Demo: vars_pack() stores several columns in one list column, one numeric
# vector per row. It binds them with cbind(), so the factors species and sex
# come back as integer codes (see the output below).
palmerpenguins::penguins %>% 
  mutate(outcome = species, 
         predictors = vars_pack(bill_length_mm, species, sex)) %>% 
    select(outcome, predictors) ->
data

data
## # A tibble: 344 × 2
##    outcome predictors
##    <fct>   <list[1d]>
##  1 Adelie  <dbl [3]> 
##  2 Adelie  <dbl [3]> 
##  3 Adelie  <dbl [3]> 
##  4 Adelie  <dbl [3]> 
##  5 Adelie  <dbl [3]> 
##  6 Adelie  <dbl [3]> 
##  7 Adelie  <dbl [3]> 
##  8 Adelie  <dbl [3]> 
##  9 Adelie  <dbl [3]> 
## 10 Adelie  <dbl [3]> 
## # ℹ 334 more rows
# First packed rows: bill_length_mm followed by the species and sex codes.
head(data$predictors)
## [[1]]
## [1] 39.1  1.0  2.0
## 
## [[2]]
## [1] 39.5  1.0  1.0
## 
## [[3]]
## [1] 40.3  1.0  1.0
## 
## [[4]]
## [1] NA  1 NA
## 
## [[5]]
## [1] 36.7  1.0  1.0
## 
## [[6]]
## [1] 39.3  1.0  2.0
# Round trip: vars_unpack() spreads the list column back into one column per
# packed variable, named after the original columns.
data %>%
    mutate(vars_unpack(predictors)) %>% 
  select(-predictors) ->
data

data
## # A tibble: 344 × 4
##    outcome bill_length_mm species   sex
##    <fct>            <dbl>   <dbl> <dbl>
##  1 Adelie            39.1       1     2
##  2 Adelie            39.5       1     1
##  3 Adelie            40.3       1     1
##  4 Adelie            NA         1    NA
##  5 Adelie            36.7       1     1
##  6 Adelie            39.3       1     2
##  7 Adelie            38.9       1     1
##  8 Adelie            39.2       1     2
##  9 Adelie            34.1       1    NA
## 10 Adelie            42         1    NA
## # ℹ 334 more rows
# compute_panel for StatLmMulti: regress y on the packed predictors (and
# optionally x) for one panel. Returns the panel rows with y replaced by the
# fitted values, yend/xend holding the observed point, and a residuals
# column, ready for geom_point()/geom_segment().
#
# data:   panel data; must contain y. x and a `predictors` list column
#         (built with vars_pack()) are optional.
# drop_x: if TRUE, x is not used as a regressor.
# ...:    absorbs the other arguments ggplot2 passes to compute_panel.
#
# NOTE(review): predictors packed with vars_pack() arrive as numbers
# (factors as integer codes), so lm() treats them as numeric, not
# categorical.
compute_lm_multi <- function(data, drop_x = TRUE, ...) {
  data <- remove_missing(data)

  # Check by name: `data$x` warns on a tibble without the column and can
  # partially match another column (e.g. `xend`) on a plain data.frame.
  no_x <- !("x" %in% names(data))
  has_predictors <- "predictors" %in% names(data)

  # Segments need an x position; without a mapped x, scatter rows randomly.
  if (no_x) {
    data$x <- runif(nrow(data))
  }

  if (has_predictors) {
    lmdata <- data %>%
      dplyr::select(x, y, predictors) %>%
      dplyr::mutate(vars_unpack(predictors)) %>%
      dplyr::select(-predictors)
  } else {
    lmdata <- dplyr::select(data, x, y)
  }

  # A synthesised x is never a regressor.
  if (drop_x || no_x) {
    lmdata <- dplyr::select(lmdata, -x)
  }

  # na.exclude pads fitted values and residuals with NA for any row lm()
  # drops, so they stay aligned with the rows of `data`.
  fit <- lm(y ~ ., data = lmdata, na.action = na.exclude)

  data$yend <- data$y
  data$y <- fitted(fit)
  data$xend <- data$x
  data$residuals <- residuals(fit)

  data
}



# Smoke test of compute_lm_multi() outside ggplot. x and predictors are
# dropped before the call, so only y reaches lm(): every fitted y is the
# same value (the mean), and x in the result is filled in by the function.
palmerpenguins::penguins %>% 
  remove_missing() %>% 
  mutate(y = flipper_length_mm, 
         x = bill_depth_mm, 
         predictors = vars_pack(sex, species)) %>% 
  select(x, y, predictors) %>% 
  select(-x, -predictors) %>% 
  compute_lm_multi(drop_x = T)
## Warning: Removed 11 rows containing missing values or values outside the scale
## range.
## Warning: Unknown or uninitialised column: `x`.
## Unknown or uninitialised column: `x`.
## Warning: Unknown or uninitialised column: `predictors`.
## # A tibble: 333 × 5
##        y     x  yend  xend residuals
##    <dbl> <dbl> <int> <dbl>     <dbl>
##  1  201. 0.929   181 0.929    -20.0 
##  2  201. 0.906   186 0.906    -15.0 
##  3  201. 0.105   195 0.105     -5.97
##  4  201. 0.579   193 0.579     -7.97
##  5  201. 0.825   190 0.825    -11.0 
##  6  201. 0.324   181 0.324    -20.0 
##  7  201. 0.268   195 0.268     -5.97
##  8  201. 0.581   182 0.581    -19.0 
##  9  201. 0.501   191 0.501     -9.97
## 10  201. 0.737   198 0.737     -2.97
## # ℹ 323 more rows
# Stat that runs compute_lm_multi() once per panel.
StatLmMulti <- ggproto(
  `_class` = "StatLmMulti",
  `_inherit` = Stat,
  compute_panel = compute_lm_multi
)

 
# Observed flipper lengths at a random x position, plus fitted values (blue
# points) and segments joining each fitted value to its observation. No
# predictors are mapped yet and x is dropped (drop_x defaults to TRUE), so
# the fit is the mean.
palmerpenguins::penguins %>% 
  remove_missing() %>% 
  mutate(x = sample(row_number())) %>% 
  ggplot() + 
  aes(y = flipper_length_mm, x = x) + 
  geom_point() + 
  geom_point(stat = StatLmMulti, alpha = .25, color = "blue") + 
  geom_segment(stat = StatLmMulti, alpha = .25, color = "blue")
## Warning: Removed 11 rows containing missing values or values outside the scale
## range.

# Computed data of the plot's first layer (the raw points).
layer_data() %>% head()
##     x   y PANEL group shape colour size fill alpha stroke
## 1 170 181     1    -1    19  black  1.5   NA    NA    0.5
## 2 138 186     1    -1    19  black  1.5   NA    NA    0.5
## 3 114 195     1    -1    19  black  1.5   NA    NA    0.5
## 4 274 193     1    -1    19  black  1.5   NA    NA    0.5
## 5 244 190     1    -1    19  black  1.5   NA    NA    0.5
## 6 271 181     1    -1    19  black  1.5   NA    NA    0.5
# Add predictors step by step through the `predictors` aesthetic; each step
# builds on the previous plot.
last_plot() + 
  aes(predictors = vars_pack(species)) + 
  aes(color = species)

last_plot() +
  aes(predictors = vars_pack(species, sex)) + 
  aes(shape = sex)

last_plot() +
  aes(predictors = vars_pack(species, sex, body_mass_g))

last_plot() + 
  aes(predictors = vars_pack(species, sex, body_mass_g, bill_depth_mm))

# compute_panel for StatLmWhatever: regress y on every other column of the
# panel data, so each extra mapped aesthetic becomes a predictor. Returns
# the data with y replaced by the fitted values, yend/xend holding the
# observed point, and a residuals column, ready for
# geom_point()/geom_segment().
#
# data:   panel data; must contain y (and x, which is copied to xend).
# scales: required by the ggproto compute_panel signature; unused.
# drop_x: if TRUE, x is not used as a regressor.
compute_lm_whatever <- function(data, scales, drop_x = FALSE) {
  # PANEL and group are ggplot2 bookkeeping columns, not variables. group is
  # a grouping id derived from the discrete aesthetics; left in the model it
  # would act as a spurious numeric predictor. any-of semantics: columns
  # that are absent are simply skipped.
  excluded <- c("PANEL", "group", if (drop_x) "x")
  lmdata <- data[setdiff(names(data), excluded)]

  # na.exclude pads fitted values and residuals with NA for any row lm()
  # drops, so they stay aligned with the rows of `data`.
  fit <- lm(y ~ ., data = lmdata, na.action = na.exclude)

  data$yend <- data$y
  data$y <- fitted(fit)
  data$xend <- data$x
  data$residuals <- residuals(fit)

  data
}

# Smoke test of compute_lm_whatever() outside ggplot, with a hand-made PANEL
# column standing in for ggplot's panel data.
palmerpenguins::penguins %>% 
  remove_missing() %>% 
  select(y = flipper_length_mm, 
         x = bill_depth_mm,
         pred1 = sex) %>% 
  mutate(PANEL = 1) %>% 
  compute_lm_whatever()
## Warning: Removed 11 rows containing missing values or values outside the scale
## range.
## # A tibble: 333 × 7
##        y     x pred1  PANEL  yend  xend residuals
##    <dbl> <dbl> <fct>  <dbl> <int> <dbl>     <dbl>
##  1  200.  18.7 male       1   181  18.7 -19.0    
##  2  192.  17.4 female     1   186  17.4  -5.94   
##  3  189.  18   female     1   195  18     6.39   
##  4  181.  19.3 female     1   193  19.3  11.6    
##  5  189.  20.6 male       1   190  20.6   0.560  
##  6  190.  17.8 female     1   181  17.8  -8.72   
##  7  195.  19.6 male       1   195  19.6  -0.00157
##  8  191.  17.6 female     1   182  17.6  -8.83   
##  9  186.  21.2 male       1   191  21.2   4.90   
## 10  187.  21.1 male       1   198  21.1  11.3    
## # ℹ 323 more rows
# Stat that runs compute_lm_whatever() once per panel.
StatLmWhatever <- ggproto(
  `_class` = "StatLmWhatever",
  `_inherit` = Stat,
  compute_panel = compute_lm_whatever
)
 
# StatLmWhatever regresses y on every other column of the layer data, so
# each mapped aesthetic becomes a predictor.
# NOTE(review): the random `x` made by mutate() is unused here, because aes()
# maps x to bill_depth_mm.
palmerpenguins::penguins %>% 
  remove_missing() %>% 
  mutate(x = sample(row_number())) %>% 
  ggplot() + 
  aes(y = flipper_length_mm, x = bill_depth_mm) + 
  geom_point() + 
  geom_point(stat = StatLmWhatever, alpha = .25) + 
  geom_segment(stat = StatLmWhatever, alpha = .25)
## Warning: Removed 11 rows containing missing values or values outside the scale
## range.

# Computed data of the plot's first layer (the raw points).
layer_data() %>% head()
##      x   y PANEL group shape colour size fill alpha stroke
## 1 18.7 181     1    -1    19  black  1.5   NA    NA    0.5
## 2 17.4 186     1    -1    19  black  1.5   NA    NA    0.5
## 3 18.0 195     1    -1    19  black  1.5   NA    NA    0.5
## 4 19.3 193     1    -1    19  black  1.5   NA    NA    0.5
## 5 20.6 190     1    -1    19  black  1.5   NA    NA    0.5
## 6 17.8 181     1    -1    19  black  1.5   NA    NA    0.5
# Add aesthetics one at a time; each enters the regression as another
# predictor. p3..p6 are not standard ggplot2 aesthetics; they appear to be
# used only to carry columns into the stat.
last_plot() + 
  aes(color = species)

last_plot() +
  aes(shape = sex)

last_plot() +
  aes(p3 = body_mass_g)

last_plot() + 
  aes(p4 = bill_length_mm)

last_plot() + 
  aes(p5 = island)

last_plot() + 
  aes(p6 = year)

# compute_group for StatSquare: turn each x value into a square with
# corners (0, 0) and (x, x), so geom_rect() draws an area of x^2. `scales`
# is required by the ggproto signature and unused.
compute_square <- function(data, scales) {
  data %>%
    mutate(
      y = x,
      xmax = x,
      ymax = x,
      xmin = 0,
      ymin = 0
    )
}


# Stat that draws a square per row via compute_square().
StatSquare <- ggproto(
  `_class` = "StatSquare",
  `_inherit` = Stat,
  compute_group = compute_square
)

# Residuals of the model in the last plot (its layer 2): a rug of the
# residuals plus, via StatSquare, a square of side r for each residual r.
layer_data(i = 2) %>% 
  ggplot() + 
  aes(x = residuals) + 
  geom_rug() +
  scale_x_continuous(limits = c(-40, 40)) + 
  geom_rect(stat = StatSquare, alpha = .2) + 
  coord_equal()

# layer_data() reads last_plot(), which is now the square plot above. Its
# layer 2 has the residual in x, so area = x^2 is a squared residual; the
# plot labels the total as the residual sum of squares.
layer_data(i = 2) %>% 
  ggplot() + 
  aes(id = "All", area = x^2) +
  ggcirclepack::geom_circlepack(alpha = .25) + 
  ggcirclepack::geom_circlepack_text() + 
  aes(label = round(after_stat(area))) + 
  labs(title = "Residual Sum of Squares")
## Warning: Unknown or uninitialised column: `wt`.
## Warning: Unknown or uninitialised column: `within`.
## Warning: Unknown or uninitialised column: `wt`.
## Warning: Unknown or uninitialised column: `within`.

# Baseline model: x is dropped (drop_x = T) and nothing else is mapped, so
# the fit is effectively intercept-only. The fitted values are the mean of
# y and the residuals are deviations from it.
palmerpenguins::penguins %>% 
  remove_missing() %>% 
  mutate(index = row_number()) %>% 
  ggplot() + 
  aes(y = flipper_length_mm, x = index) + 
  geom_point() + 
  geom_point(stat = StatLmWhatever, alpha = .25, drop_x = T) + 
  geom_segment(stat = StatLmWhatever, alpha = .25, drop_x = T) 
## Warning: Removed 11 rows containing missing values or values outside the scale
## range.

# Squares of the deviations from the mean.
layer_data(i = 2) %>% 
  ggplot() + 
  aes(x = residuals) + 
  geom_rug() + 
  geom_rect(stat = StatSquare, alpha = .2) +
  scale_x_continuous(limits = c(-40, 40)) + 
  coord_equal()

# As above, but the squared deviations from the mean add up to the total sum
# of squares.
layer_data(i = 2) %>% 
  ggplot() + 
  aes(id = "All", area = x^2) +
  ggcirclepack::geom_circlepack(alpha = .25) + 
  ggcirclepack::geom_circlepack_text() + 
  aes(label = round(after_stat(area))) + 
  labs(title = "Total Sum of Squares")
## Warning: Unknown or uninitialised column: `wt`.
## Unknown or uninitialised column: `within`.
## Warning: Unknown or uninitialised column: `wt`.
## Warning: Unknown or uninitialised column: `within`.
# R-squared = (TSS - RSS) / TSS, using TSS ~ 65219 and RSS ~ 7567
# (presumably the totals labelled on the two circle-pack plots above).
# The denominator must be TSS itself.
(65219 - 7567) / 65219  # R-squared 0.884
## [1] 0.8839755
# Cross-check: lm() with every penguin column as a predictor (the same set
# mapped in the last StatLmWhatever plot) reports Multiple R-squared 0.884.
lm(flipper_length_mm ~ ., data = palmerpenguins::penguins) %>% 
  summary()
## 
## Call:
## lm(formula = flipper_length_mm ~ ., data = palmerpenguins::penguins)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -14.9290  -2.9850  -0.0741   3.0890  14.3291 
## 
## Coefficients:
##                    Estimate Std. Error t value Pr(>|t|)    
## (Intercept)      -4.616e+03  6.646e+02  -6.947 2.07e-11 ***
## speciesChinstrap  2.850e+00  1.515e+00   1.881  0.06087 .  
## speciesGentoo     2.241e+01  2.261e+00   9.912  < 2e-16 ***
## islandDream       1.706e+00  9.821e-01   1.738  0.08325 .  
## islandTorgersen   2.923e+00  1.017e+00   2.876  0.00430 ** 
## bill_length_mm    2.728e-01  1.205e-01   2.263  0.02428 *  
## bill_depth_mm     1.026e+00  3.379e-01   3.037  0.00258 ** 
## body_mass_g       5.281e-03  8.929e-04   5.915 8.46e-09 ***
## sexmale           7.825e-01  8.858e-01   0.883  0.37768    
## year              2.368e+00  3.309e-01   7.156 5.60e-12 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4.841 on 323 degrees of freedom
##   (11 observations deleted due to missingness)
## Multiple R-squared:  0.884,  Adjusted R-squared:  0.8807 
## F-statistic: 273.4 on 9 and 323 DF,  p-value: < 2.2e-16
# Stop knitting here; nothing after this line is rendered.
knitr::knit_exit()