# stat-smooth

library(ggplot2)


# Layer constructor for the smoothing stat.
#
# Calls `layer()` with `StatSmooth` as the stat and forwards the smoothing
# settings as the layer's `params`.
#
# Arguments:
#   mapping, data, geom, position, show.legend, inherit.aes:
#     handed to `layer()` unchanged.
#   ...          Additional parameters spliced into the `params` list.
#   method       Smoothing function or its name. `NULL` (or "auto") is resolved
#                by `StatSmooth$setup_params()`.
#   formula      Model formula. `NULL` is resolved by
#                `StatSmooth$setup_params()`.
#   se           Whether a confidence band is computed.
#   n            Number of points the fit is evaluated at when `x` is not
#                integer-valued.
#   span         Stored in `method.args$span` when `method` is "loess".
#   fullrange    Evaluate over the x scale's dimension instead of the data
#                range.
#   level        Confidence level used for the band.
#   method.args  List of further arguments spliced into the call to `method`.
#   na.rm, orientation:
#     forwarded as layer parameters (declared as `extra_params` of
#     `StatSmooth`).
#
# Returns whatever `layer()` returns.
stat_smooth <- function(mapping = NULL, data = NULL,
                        geom = "smooth", position = "identity",
                        ...,
                        method = NULL,
                        formula = NULL,
                        se = TRUE,
                        n = 80,
                        span = 0.75,
                        fullrange = FALSE,
                        level = 0.95,
                        method.args = list(),
                        na.rm = FALSE,
                        orientation = NA,
                        show.legend = NA,
                        inherit.aes = TRUE) {
  layer(
    data = data,
    mapping = mapping,
    stat = StatSmooth,
    geom = geom,
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    # `list2()` lives in rlang, which `library(ggplot2)` does not attach, so it
    # is namespace-qualified here just like `rlang::inject()` further down.
    params = rlang::list2(
      method = method,
      formula = formula,
      se = se,
      n = n,
      fullrange = fullrange,
      level = level,
      na.rm = na.rm,
      orientation = orientation,
      method.args = method.args,
      span = span,
      ...
    )
  )
}

# Thin indirection that hands back `mgcv::gam`. Its only purpose is to keep
# the reference to mgcv from triggering an undeclared import warning.
gam_method <- function() {
  mgcv::gam
}
# Build a data frame from `...` without repairing column names.
#
# The constructor is namespace-qualified on purpose: once dplyr/tibble is
# attached, a bare `data_frame()` resolves to tibble's deprecated function and
# emits a lifecycle warning. `vctrs::data_frame()` returns a plain data.frame;
# calling this with no arguments gives an empty (0 x 0) data frame.
data_frame0 <- function(...) {
  vctrs::data_frame(..., .name_repair = "minimal")
}

# Fit a smoother to one group of data and evaluate it on a grid of x values.
#
# Arguments:
#   data         Data frame with `x` and `y` columns (after `flip_data()`), and
#                optionally `weight`; a missing `weight` is filled with 1.
#   scales       Only used when `fullrange = TRUE`, where
#                `scales$x$dimension()` supplies the evaluation range.
#   method       Model-fitting function, or its name as a string. It must
#                already be resolved: the `NULL` default is not callable here
#                (`StatSmooth$setup_params()` replaces NULL/"auto" upstream).
#   formula      Formula handed to `method` as its first argument.
#   se, level    Passed on to `predictdf()`.
#   n            Number of grid points when `x` is not integer-typed.
#   span         Written into `method.args$span` when `method` is "loess".
#   fullrange    Use the scale's dimension rather than the data range.
#   xseq         Optional explicit evaluation points; built here when NULL.
#   method.args  List of extra arguments spliced into the `method` call.
#   na.rm        Accepted but not used in this function.
#   flipped_aes  Passed to `flip_data()` and stored on the result.
#
# Returns the data frame produced by `predictdf()` with a `flipped_aes` column
# added, passed back through `flip_data()`. If `x` has fewer than two unique
# values, an empty data frame is returned instead.
compute_group_smooth <- function(data, scales, method = NULL, formula = NULL,
                                 se = TRUE, n = 80, span = 0.75,
                                 fullrange = FALSE, xseq = NULL, level = 0.95,
                                 method.args = list(), na.rm = FALSE,
                                 flipped_aes = NA) {
  data <- flip_data(data, flipped_aes)
  if (vctrs::vec_unique_count(data$x) < 2) {
    # Not enough data to perform fit
    return(data_frame0())
  }

  if (is.null(data$weight)) {
    data$weight <- 1
  }

  if (is.null(xseq)) {
    if (is.integer(data$x)) {
      # Integer x: evaluate at the observed values only (or the scale's
      # dimension), rather than on an interpolated grid.
      if (fullrange) {
        xseq <- scales$x$dimension()
      } else {
        # Base `unique()` is used because ggplot2's internal `unique0()` is
        # not exported and is not defined in this file.
        xseq <- sort(unique(data$x))
      }
    } else {
      x_range <- if (fullrange) {
        scales$x$dimension()
      } else {
        range(data$x, na.rm = TRUE)
      }
      xseq <- seq(x_range[1], x_range[2], length.out = n)
    }
  }

  # Special case span because it's the most commonly used model argument
  if (identical(method, "loess")) {
    method.args$span <- span
  }

  if (is.character(method)) {
    if (identical(method, "gam")) {
      method <- gam_method()
    } else {
      method <- match.fun(method)
    }
  }
  # If gam and gam's method is not specified by the user then use REML
  if (identical(method, gam_method()) && is.null(method.args$method)) {
    method.args$method <- "REML"
  }

  # `weight` is left as a bare symbol in the constructed call; the `!!!`
  # splice expands `method.args` into individual arguments.
  model <- rlang::inject(method(
    formula,
    data = data,
    weights = weight,
    !!!method.args
  ))

  prediction <- predictdf(model, xseq, se, level)
  prediction$flipped_aes <- flipped_aes
  flip_data(prediction, flipped_aes)
}

# S3 generic: turn a fitted `model` into a data frame of predictions at the
# points `xseq`. `se` and `level` are handed to the method selected by the
# class of `model`.
predictdf <- function(model, xseq, se, level) {
  UseMethod("predictdf")
}

#' @export
# Fallback method: relies on `stats::predict()` accepting `se.fit`, `level`
# and `interval` for this model class.
#
# With `se = TRUE` the result has columns x, y, ymin, ymax and se, where the
# three fit columns come from `pred$fit` in that order. With `se = FALSE` it
# has only x and y.
predictdf.default <- function(model, xseq, se, level) {
  interval_type <- if (se) "confidence" else "none"
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq),
    se.fit = se,
    level = level,
    interval = interval_type
  )

  if (!se) {
    return(base::data.frame(x = xseq, y = as.vector(pred)))
  }

  band <- stats::setNames(as.data.frame(pred$fit), c("y", "ymin", "ymax"))
  base::data.frame(x = xseq, band, se = pred$se.fit)
}

#' @export
# GLM method: predictions are made on the link scale and mapped back with the
# family's inverse link. With `se = TRUE` the band is fit +/- z * se on the
# link scale (z is the normal quantile for `level`), transformed afterwards;
# the `se` column itself stays on the link scale.
predictdf.glm <- function(model, xseq, se, level) {
  link_pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq),
    se.fit = se,
    type = "link"
  )
  inverse_link <- model$family$linkinv

  if (!se) {
    return(base::data.frame(x = xseq, y = inverse_link(as.vector(link_pred))))
  }

  z <- stats::qnorm(level / 2 + 0.5)
  eta <- link_pred$fit
  eta_se <- link_pred$se.fit
  base::data.frame(
    x = xseq,
    y = inverse_link(as.vector(eta)),
    ymin = inverse_link(as.vector(eta - z * eta_se)),
    ymax = inverse_link(as.vector(eta + z * eta_se)),
    se = as.vector(eta_se)
  )
}

#' @export
# loess method: with `se = TRUE` the band is the fit plus/minus the standard
# error times a t quantile taken at `level / 2 + .5` with `pred$df` degrees of
# freedom. With `se = FALSE` only x and y are returned.
predictdf.loess <- function(model, xseq, se, level) {
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq),
    se = se
  )

  if (!se) {
    return(base::data.frame(x = xseq, y = as.vector(pred)))
  }

  centre <- pred$fit
  half_width <- pred$se.fit * stats::qt(level / 2 + .5, pred$df)
  base::data.frame(
    x = xseq,
    y = centre,
    ymin = centre - half_width,
    ymax = centre + half_width,
    se = pred$se.fit
  )
}

#' @export
# locfit method: same shape of result as the loess method, except that the t
# quantile uses `model$dp["df2"]` as its degrees of freedom and the standard
# errors are requested through `se.fit`.
predictdf.locfit <- function(model, xseq, se, level) {
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq),
    se.fit = se
  )

  if (!se) {
    return(base::data.frame(x = xseq, y = as.vector(pred)))
  }

  centre <- pred$fit
  half_width <- pred$se.fit * stats::qt(level / 2 + .5, model$dp["df2"])
  base::data.frame(
    x = xseq,
    y = centre,
    ymin = centre - half_width,
    ymax = centre + half_width,
    se = pred$se.fit
  )
}

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Example: call the compute function directly on mtcars with a linear model.
# The columns are renamed because the formula refers to `x` and `y`; `scales`
# is left unsupplied, which works here because `fullrange` keeps its FALSE
# default.
compute_group_smooth(
  rename(mtcars, x = wt, y = mpg),
  method = lm,
  formula = y ~ x
)
## Warning: `data_frame()` was deprecated in tibble 1.1.0.
## ℹ Please use `tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
##           x         y      ymin     ymax        se flipped_aes
## 1  1.513000 29.198941 26.963760 31.43412 1.0944578          NA
## 2  1.562506 28.934356 26.748212 31.12050 1.0704467          NA
## 3  1.612013 28.669770 26.532294 30.80725 1.0466168          NA
## 4  1.661519 28.405185 26.315980 30.49439 1.0229807          NA
## 5  1.711025 28.140600 26.099242 30.18196 0.9995522          NA
## 6  1.760532 27.876015 25.882050 29.86998 0.9763462          NA
## 7  1.810038 27.611430 25.664370 29.55849 0.9533790          NA
## 8  1.859544 27.346844 25.446166 29.24752 0.9306683          NA
## 9  1.909051 27.082259 25.227400 28.93712 0.9082332          NA
## 10 1.958557 26.817674 25.008027 28.62732 0.8860947          NA
## 11 2.008063 26.553089 24.788003 28.31818 0.8642757          NA
## 12 2.057570 26.288504 24.567275 28.00973 0.8428009          NA
## 13 2.107076 26.023919 24.345789 27.70205 0.8216973          NA
## 14 2.156582 25.759333 24.123485 27.39518 0.8009943          NA
## 15 2.206089 25.494748 23.900298 27.08920 0.7807237          NA
## 16 2.255595 25.230163 23.676157 26.78417 0.7609202          NA
## 17 2.305101 24.965578 23.450986 26.48017 0.7416210          NA
## 18 2.354608 24.700993 23.224702 26.17728 0.7228666          NA
## 19 2.404114 24.436408 22.997217 25.87560 0.7047005          NA
## 20 2.453620 24.171822 22.768435 25.57521 0.6871694          NA
## 21 2.503127 23.907237 22.538255 25.27622 0.6703230          NA
## 22 2.552633 23.642652 22.306568 24.97874 0.6542143          NA
## 23 2.602139 23.378067 22.073261 24.68287 0.6388992          NA
## 24 2.651646 23.113482 21.838214 24.38875 0.6244358          NA
## 25 2.701152 22.848897 21.601303 24.09649 0.6108849          NA
## 26 2.750658 22.584311 21.362403 23.80622 0.5983083          NA
## 27 2.800165 22.319726 21.121385 23.51807 0.5867688          NA
## 28 2.849671 22.055141 20.878121 23.23216 0.5763286          NA
## 29 2.899177 21.790556 20.632489 22.94862 0.5670484          NA
## 30 2.948684 21.525971 20.384369 22.66757 0.5589861          NA
## 31 2.998190 21.261386 20.133653 22.38912 0.5521949          NA
## 32 3.047696 20.996800 19.880245 22.11336 0.5467223          NA
## 33 3.097203 20.732215 19.624062 21.84037 0.5426081          NA
## 34 3.146709 20.467630 19.365041 21.57022 0.5398835          NA
## 35 3.196215 20.203045 19.103139 21.30295 0.5385694          NA
## 36 3.245722 19.938460 18.838336 21.03858 0.5386762          NA
## 37 3.295228 19.673875 18.570633 20.77712 0.5402031          NA
## 38 3.344734 19.409289 18.300053 20.51853 0.5431381          NA
## 39 3.394241 19.144704 18.026645 20.26276 0.5474586          NA
## 40 3.443747 18.880119 17.750473 20.00977 0.5531320          NA
## 41 3.493253 18.615534 17.471622 19.75945 0.5601173          NA
## 42 3.542759 18.350949 17.190190 19.51171 0.5683661          NA
## 43 3.592266 18.086364 16.906289 19.26644 0.5778243          NA
## 44 3.641772 17.821778 16.620037 19.02352 0.5884336          NA
## 45 3.691278 17.557193 16.331558 18.78283 0.6001329          NA
## 46 3.740785 17.292608 16.040981 18.54423 0.6128598          NA
## 47 3.790291 17.028023 15.748434 18.30761 0.6265517          NA
## 48 3.839797 16.763438 15.454041 18.07283 0.6411468          NA
## 49 3.889304 16.498853 15.157927 17.83978 0.6565849          NA
## 50 3.938810 16.234267 14.860210 17.60832 0.6728079          NA
## 51 3.988316 15.969682 14.561004 17.37836 0.6897604          NA
## 52 4.037823 15.705097 14.260414 17.14978 0.7073900          NA
## 53 4.087329 15.440512 13.958542 16.92248 0.7256474          NA
## 54 4.136835 15.175927 13.655483 16.69637 0.7444863          NA
## 55 4.186342 14.911342 13.351324 16.47136 0.7638638          NA
## 56 4.235848 14.646756 13.046146 16.24737 0.7837398          NA
## 57 4.285354 14.382171 12.740026 16.02432 0.8040775          NA
## 58 4.334861 14.117586 12.433033 15.80214 0.8248426          NA
## 59 4.384367 13.853001 12.125231 15.58077 0.8460038          NA
## 60 4.433873 13.588416 11.816679 15.36015 0.8675320          NA
## 61 4.483380 13.323831 11.507432 15.14023 0.8894005          NA
## 62 4.532886 13.059245 11.197541 14.92095 0.9115849          NA
## 63 4.582392 12.794660 10.887050 14.70227 0.9340627          NA
## 64 4.631899 12.530075 10.576002 14.48415 0.9568132          NA
## 65 4.681405 12.265490 10.264436 14.26654 0.9798174          NA
## 66 4.730911 12.000905  9.952387 14.04942 1.0030578          NA
## 67 4.780418 11.736320  9.639889 13.83275 1.0265184          NA
## 68 4.829924 11.471734  9.326972 13.61650 1.0501844          NA
## 69 4.879430 11.207149  9.013662 13.40064 1.0740423          NA
## 70 4.928937 10.942564  8.699986 13.18514 1.0980796          NA
## 71 4.978443 10.677979  8.385968 12.96999 1.1222847          NA
## 72 5.027949 10.413394  8.071628 12.75516 1.1466470          NA
## 73 5.077456 10.148809  7.756987 12.54063 1.1711567          NA
## 74 5.126962  9.884223  7.442064 12.32638 1.1958047          NA
## 75 5.176468  9.619638  7.126876 12.11240 1.2205827          NA
## 76 5.225975  9.355053  6.811438 11.89867 1.2454828          NA
## 77 5.275481  9.090468  6.495765 11.68517 1.2704980          NA
## 78 5.324987  8.825883  6.179871 11.47189 1.2956215          NA
## 79 5.374494  8.561298  5.863768 11.25883 1.3208472          NA
## 80 5.424000  8.296712  5.547468 11.04596 1.3461693          NA
#' @rdname ggplot2-ggproto
#' @format NULL
#' @usage NULL
#' @export
# ggproto stat built on `ggplot2::Stat`. `setup_params()` fills in a default
# method and formula when none is given and reports the choices it made;
# the per-group work is delegated to `compute_group_smooth()`.
StatSmooth <- ggplot2::ggproto("StatSmooth", ggplot2::Stat,
  setup_params = function(data, params) {
    params$flipped_aes <- has_flipped_aes(data, params, ambiguous = TRUE)
    # NOTE: this local must stay named `msg`; the cli template below
    # interpolates it by name.
    msg <- character()

    needs_method <- is.null(params$method) || identical(params$method, "auto")
    if (needs_method) {
      # loess for small datasets, gam with a cubic regression basis for larger
      # ones. The decision is based on the size of the _largest_ group, to
      # avoid the bad memory behaviour of loess.
      group_sizes <- table(interaction(data$group, data$PANEL, drop = TRUE))
      params$method <- if (max(group_sizes) < 1000) "loess" else "gam"
      msg <- c(msg, paste0("method = '", params$method, "'"))
    }

    uses_gam <- identical(params$method, "gam")

    if (is.null(params$formula)) {
      params$formula <- if (uses_gam) y ~ s(x, bs = "cs") else y ~ x
      msg <- c(msg, paste0("formula = '", deparse(params$formula), "'"))
    }

    # Swap the "gam" string for the actual fitting function.
    if (uses_gam) {
      params$method <- gam_method()
    }

    if (length(msg) > 0) {
      cli::cli_inform("{.fn geom_smooth} using {msg}")
    }

    params
  },
  extra_params = c("na.rm", "orientation"),
  compute_group = compute_group_smooth,
  dropped_aes = c("weight"),
  required_aes = c("x", "y")
)