Proposal: Can we allow categorical values in model?

Target code

mtcars %>% 
  ggplot() + 
  aes(x = wt, y = mpg, cat = am) + 
  geom_smooth_cat()
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.0     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.1     ✔ tibble    3.2.0
## ✔ lubridate 1.9.2     ✔ tidyr     1.3.0
## ✔ purrr     1.0.1     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# This function exists to silence an undeclared import warning
gam_method <- function() mgcv::gam
data_frame0 <- function(...) data_frame(..., .name_repair = "minimal")
predictdf <- function(model, xseq, catseq, se, level) UseMethod("predictdf")

prediction mechanisms

#' @export
predictdf.default <- function(model, xseq, catseq, se, level) {
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq, cat = catseq),
    se.fit = se,
    level = level,
    interval = if (se) "confidence" else "none"
  )

  if (se) {
    fit <- as.data.frame(pred$fit)
    names(fit) <- c("y", "ymin", "ymax")
    base::data.frame(x = xseq, cat = catseq, fit, se = pred$se.fit)
  } else {
    base::data.frame(x = xseq, cat = catseq, y = as.vector(pred))
  }
}

#' @export
predictdf.glm <- function(model, xseq, catseq, se, level) {
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq, cat = catseq),
    se.fit = se,
    type = "link"
  )

  if (se) {
    std <- stats::qnorm(level / 2 + 0.5)
    base::data.frame(
      x = xseq,
      cat = catseq,
      y = model$family$linkinv(as.vector(pred$fit)),
      ymin = model$family$linkinv(as.vector(pred$fit - std * pred$se.fit)),
      ymax = model$family$linkinv(as.vector(pred$fit + std * pred$se.fit)),
      se = as.vector(pred$se.fit)
    )
  } else {
    base::data.frame(x = xseq, cat = catseq, y = model$family$linkinv(as.vector(pred)))
  }
}

#' @export
predictdf.loess <- function(model, xseq, catseq, se, level) {
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq, cat = catseq),
    se = se
  )

  if (se) {
    y <- pred$fit
    ci <- pred$se.fit * stats::qt(level / 2 + .5, pred$df)
    ymin <- y - ci
    ymax <- y + ci
    base::data.frame(x = xseq, cat = catseq, y, ymin, ymax, se = pred$se.fit)
  } else {
    base::data.frame(x = xseq, cat = catseq, y = as.vector(pred))
  }
}

#' @export
predictdf.locfit <- function(model, xseq, se, level) {
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq),
    se.fit = se
  )

  if (se) {
    y <- pred$fit
    ci <- pred$se.fit * stats::qt(level / 2 + .5, model$dp["df2"])
    ymin <- y - ci
    ymax <- y + ci
    base::data.frame(x = xseq, cat = catseq, y, ymin, ymax, se = pred$se.fit)
  } else {
    base::data.frame(x = xseq, cat = catseq, y = as.vector(pred))
  }
}

compute_group_smooth

# would be intersted in changeing the x sequence to observed values of x, for drawing fitted and residuals

compute_cat_model_smooth <- function(data, scales, method = NULL, formula = NULL,
                           se = TRUE, n = 80, span = 0.75, fullrange = FALSE,
                           xseq = NULL, catseq = NULL,
                           level = 0.95, method.args = list(),
                           na.rm = FALSE, flipped_aes = NA) {
    data <- flip_data(data, flipped_aes)
    if (vctrs::vec_unique_count(data$x) < 2) {
      # Not enough data to perform fit
      return(data_frame0())
    }

    if (is.null(data$weight)) data$weight <- 1

    the_cats <- unique(data$cat)
    num_cats <- length(the_cats)
    
    if(is.null(catseq)){catseq <- sort(rep(the_cats, n))}
    if(is.null(xseq)){
      
      xsequences <- list() 
      xseq <- c()
      
      for(i in 1:num_cats){
        
        subset <- data %>% filter(cat == the_cats[i])
        
        xsequences[[i]] <- seq(min(subset$x),max(subset$x), length.out = n)
        
        xseq <- c(xseq, xsequences[[i]])

      }
      
      
    }
    
    
    # # which values of x should we predict for ? 
    # if (is.null(xseq)) {
    #   if (is.integer(data$x)) {
    #     if (fullrange) {
    #       xseq <- scales$x$dimension()
    #     } else {
    #       xseq <- sort(unique0(data$x))
    #     }
    #   } else {
    #     if (fullrange) {
    #       range <- scales$x$dimension()
    #     } else {
    #       range <- range(data$x, na.rm = TRUE)
    #     }
    #     xseq <- seq(range[1], range[2], length.out = n)
    #   }
    # }

    # Special case span because it's the most commonly used model argument
    if (identical(method, "loess")) {
      method.args$span <- span
    }

    if (is.character(method)) {
      if (identical(method, "gam")) {
        method <- gam_method()
      } else {
        method <- match.fun(method)
      }
    }
    # If gam and gam's method is not specified by the user then use REML
    if (identical(method, gam_method()) && is.null(method.args$method)) {
      method.args$method <- "REML"
    }

    model <- rlang::inject(method(
      formula,
      data = data,
      weights = weight,
      !!!method.args
    ))

    prediction <- predictdf(model, xseq, catseq, se, level)
    prediction$flipped_aes <- flipped_aes
    flip_data(prediction, flipped_aes)
}

test compute group

library(dplyr)
mtcars %>%
  rename(x = wt, y = mpg) %>% 
  mutate(cat = am) %>%
  slice(1:5) %>% 
  compute_cat_model_smooth(method = lm, formula = y ~ x + cat, n = 7)
## Warning: `data_frame()` was deprecated in tibble 1.1.0.
## ℹ Please use `tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
##         x cat        y      ymin     ymax       se flipped_aes
## 1  2.3200   0 24.63012 13.717322 35.54292 2.536295          NA
## 2  2.4125   0 24.20961 14.213523 34.20570 2.323239          NA
## 3  2.5050   0 23.78910 14.700712 32.87750 2.112277          NA
## 4  2.5975   0 23.36860 15.175891 31.56130 1.904106          NA
## 5  2.6900   0 22.94809 15.634648 30.26153 1.699752          NA
## 6  2.7825   0 22.52758 16.070270 28.98490 1.500775          NA
## 7  2.8750   0 22.10708 16.472202 27.74195 1.309628          NA
## 8  3.2150   1 18.82693 12.003340 25.65051 1.585902          NA
## 9  3.2525   1 18.65645 11.470073 25.84283 1.670220          NA
## 10 3.2900   1 18.48597 10.934163 26.03778 1.755152          NA
## 11 3.3275   1 18.31550 10.395977 26.23502 1.840614          NA
## 12 3.3650   1 18.14502  9.855817 26.43423 1.926534          NA
## 13 3.4025   1 17.97455  9.313935 26.63516 2.012854          NA
## 14 3.4400   1 17.80407  8.770545 26.83760 2.099525          NA
mtcars %>%
  rename(x = wt, y = mpg) %>% 
  mutate(cat = am) %>%
  slice(1:15) %>% 
  ggplot() + 
  aes(x, y) + 
  geom_point(aes(color = factor(cat))) + 
  geom_point(data = . %>% 
            compute_cat_model_smooth(method = lm, 
                                    formula = y ~ x + cat, 
                                    n = 10), color = "blue")

#' @rdname ggplot2-ggproto
#' @format NULL
#' @usage NULL
#' @export
StatSmoothcat <- ggplot2::ggproto("StatSmoothcat", ggplot2::Stat,
  setup_params = StatSmooth$setup_params,
  extra_params = c("na.rm", "orientation"),
  compute_group = StatSmooth$compute_group,  #compute_cat_model_smooth,
  dropped_aes = c("weight"),
  required_aes = c("x", "y")
)



stat_fit <- function(mapping = NULL, data = NULL,
            geom = "point", position = "identity",
            ...,
            method = NULL,
            formula = NULL,
            se = TRUE,
            n = 80,
            span = 0.75,
            fullrange = FALSE,
            level = 0.95,
            method.args = list(),
            na.rm = FALSE,
            orientation = NA,
            show.legend = NA,
            inherit.aes = TRUE) {
  layer(
    data = data,
    mapping = mapping,
    stat = StatSmoothcat,
    geom = geom,
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = rlang::list2(
      method = method,
      formula = formula,
      se = se,
      n = n,
      fullrange = fullrange,
      level = level,
      na.rm = na.rm,
      orientation = orientation,
      method.args = method.args,
      span = span,
      ...
    )
  )
}

New…

mtcars %>% 
  ggplot() +
  aes(x = wt, y = mpg) + 
  geom_point(aes(color = am)) +
  stat_fit(
                  n = 5,
                  color = "red")
## Warning in stat_fit(n = 5, color = "red"): Ignoring unknown parameters: `method`, `formula`, `se`, `n`, `fullrange`,
## `level`, `method.args`, and `span`
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
## Warning: Computation failed in `stat_smoothcat()`
## Caused by error in `method()`:
## ! could not find function "method"