# Example call: plot mpg against wt from mtcars, with `am` mapped to the
# non-standard `cat` aesthetic.
# NOTE(review): `geom_smooth_cat()` is not defined in the visible part of this
# file -- confirm where it is supposed to come from.
ggplot(mtcars) +
  aes(x = wt, y = mpg, cat = am) +
  geom_smooth_cat()
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.0 ✔ readr 2.1.4
## ✔ forcats 1.0.0 ✔ stringr 1.5.0
## ✔ ggplot2 3.4.1 ✔ tibble 3.2.0
## ✔ lubridate 1.9.2 ✔ tidyr 1.3.0
## ✔ purrr 1.0.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
# Accessor for `mgcv::gam`. It exists only so that the reference to mgcv goes
# through `::`, which silences an undeclared import warning.
gam_method <- function() {
  mgcv::gam
}
# Build a tibble from `...` with column-name repair switched off
# (`.name_repair = "minimal"`).
#
# Uses `tibble::tibble()` rather than `data_frame()`: the latter is deprecated
# (since tibble 1.1.0) and emits a lifecycle warning every time this helper is
# first used in a session.
data_frame0 <- function(...) {
  tibble::tibble(..., .name_repair = "minimal")
}
# S3 generic: predict from `model` on the paired grids `xseq` / `catseq`.
# Dispatches on the class of `model`; `se` and `level` are forwarded to the
# selected method.
predictdf <- function(model, xseq, catseq, se, level) {
  UseMethod("predictdf")
}
#' Default `predictdf()` method
#'
#' Calls `stats::predict()` on a grid whose columns are `x` and `cat`.
#'
#' @param model Fitted model object passed to `stats::predict()`.
#' @param xseq Values for the `x` column of the prediction grid.
#' @param catseq Values for the `cat` column of the prediction grid; placed in
#'   the same rows as `xseq`.
#' @param se If `TRUE`, request a confidence interval and standard errors.
#' @param level Confidence level forwarded to `stats::predict()`.
#' @return A data frame with columns `x`, `cat`, `y`, plus `ymin`, `ymax` and
#'   `se` when `se` is `TRUE`.
#' @export
predictdf.default <- function(model, xseq, catseq, se, level) {
  grid <- data_frame0(x = xseq, cat = catseq)
  interval_type <- if (se) "confidence" else "none"

  pred <- stats::predict(
    model,
    newdata = grid,
    se.fit = se,
    level = level,
    interval = interval_type
  )

  # Without standard errors `pred` is used directly as the fitted values.
  if (!se) {
    return(base::data.frame(x = xseq, cat = catseq, y = as.vector(pred)))
  }

  # With standard errors `pred$fit` carries three columns: fit, lower, upper.
  bounds <- as.data.frame(pred$fit)
  names(bounds) <- c("y", "ymin", "ymax")
  base::data.frame(x = xseq, cat = catseq, bounds, se = pred$se.fit)
}
#' `predictdf()` method for `glm` fits
#'
#' Predicts on the link scale and maps fitted values (and, when requested,
#' normal-quantile interval bounds) back through the family's inverse link.
#'
#' @param model Fitted `glm`; `model$family$linkinv` is used for the
#'   back-transformation.
#' @param xseq Values for the `x` column of the prediction grid.
#' @param catseq Values for the `cat` column of the prediction grid.
#' @param se If `TRUE`, also return `ymin`, `ymax` and `se`.
#' @param level Confidence level used for the normal quantile.
#' @return A data frame with columns `x`, `cat`, `y`, plus `ymin`, `ymax` and
#'   `se` (on the link scale) when `se` is `TRUE`.
#' @export
predictdf.glm <- function(model, xseq, catseq, se, level) {
  grid <- data_frame0(x = xseq, cat = catseq)
  pred <- stats::predict(model, newdata = grid, se.fit = se, type = "link")
  inv_link <- model$family$linkinv

  if (!se) {
    return(
      base::data.frame(x = xseq, cat = catseq, y = inv_link(as.vector(pred)))
    )
  }

  # Interval is built on the link scale, then back-transformed.
  z <- stats::qnorm(level / 2 + 0.5)
  eta <- pred$fit
  eta_se <- pred$se.fit

  base::data.frame(
    x = xseq,
    cat = catseq,
    y = inv_link(as.vector(eta)),
    ymin = inv_link(as.vector(eta - z * eta_se)),
    ymax = inv_link(as.vector(eta + z * eta_se)),
    se = as.vector(eta_se)
  )
}
#' `predictdf()` method for `loess` fits
#'
#' @param model Fitted `loess` object.
#' @param xseq Values for the `x` column of the prediction grid.
#' @param catseq Values for the `cat` column of the prediction grid.
#' @param se If `TRUE`, also return `ymin`, `ymax` and `se`.
#' @param level Confidence level used for the t quantile.
#' @return A data frame with columns `x`, `cat`, `y`, plus `ymin`, `ymax` and
#'   `se` when `se` is `TRUE`.
#' @export
predictdf.loess <- function(model, xseq, catseq, se, level) {
  grid <- data_frame0(x = xseq, cat = catseq)
  pred <- stats::predict(model, newdata = grid, se = se)

  if (!se) {
    return(base::data.frame(x = xseq, cat = catseq, y = as.vector(pred)))
  }

  # Half-width of the interval: standard error times a t quantile on the
  # degrees of freedom reported by predict().
  fitted_y <- pred$fit
  half_width <- pred$se.fit * stats::qt(level / 2 + .5, pred$df)

  base::data.frame(
    x = xseq,
    cat = catseq,
    y = fitted_y,
    ymin = fitted_y - half_width,
    ymax = fitted_y + half_width,
    se = pred$se.fit
  )
}
#' `predictdf()` method for `locfit` fits
#'
#' The signature now matches the `predictdf()` generic and the sibling
#' methods. Previously `catseq` was missing from the argument list although the
#' generic passes it and the body uses it, so dispatching to this method could
#' not work.
#'
#' @param model Fitted `locfit` object; `model$dp["df2"]` supplies the degrees
#'   of freedom for the t quantile.
#' @param xseq Values for the `x` column of the prediction grid.
#' @param catseq Values for the `cat` column of the prediction grid.
#' @param se If `TRUE`, also return `ymin`, `ymax` and `se`.
#' @param level Confidence level used for the t quantile.
#' @return A data frame with columns `x`, `cat`, `y`, plus `ymin`, `ymax` and
#'   `se` when `se` is `TRUE`.
#' @export
predictdf.locfit <- function(model, xseq, catseq, se, level) {
  # `cat` is included in the grid for consistency with the other methods.
  # NOTE(review): assumes predict() for locfit ignores newdata columns that are
  # not part of the model formula -- confirm.
  pred <- stats::predict(
    model,
    newdata = data_frame0(x = xseq, cat = catseq),
    se.fit = se
  )

  if (se) {
    y <- pred$fit
    ci <- pred$se.fit * stats::qt(level / 2 + .5, model$dp["df2"])
    ymin <- y - ci
    ymax <- y + ci
    base::data.frame(x = xseq, cat = catseq, y, ymin, ymax, se = pred$se.fit)
  } else {
    base::data.frame(x = xseq, cat = catseq, y = as.vector(pred))
  }
}
# Would be interested in changing the x sequence to the observed values of x, for drawing fitted values and residuals.
# Fit a single model to `data` and predict it on one x grid per category.
#
# Arguments:
#   data         data frame with columns `x`, `y`, `cat` and optionally
#                `weight` (a weight of 1 is used when absent).
#   scales       currently unused (only referenced by the disabled range
#                logic kept below).
#   method       model-fitting function, or its name as a string; "gam" is
#                resolved to `mgcv::gam`. Must not be NULL.
#   formula      model formula handed to `method`.
#   se, level    forwarded to `predictdf()`.
#   n            number of grid points per category.
#   span         copied into `method.args$span` when `method` is "loess".
#   fullrange    currently unused (see the disabled range logic below).
#   xseq, catseq optional user-supplied prediction grid; when NULL they are
#                built from the data. They are paired row by row, so callers
#                supplying only one of them must match the other's length.
#   method.args  extra arguments spliced into the call to `method`.
#   na.rm        currently unused in this function.
#   flipped_aes  forwarded to `flip_data()` and stored in the result.
#
# Returns the data frame produced by `predictdf()` with a `flipped_aes`
# column added, passed back through `flip_data()`; an empty tibble when `x`
# has fewer than two unique values.
compute_cat_model_smooth <- function(data, scales, method = NULL, formula = NULL,
                                     se = TRUE, n = 80, span = 0.75, fullrange = FALSE,
                                     xseq = NULL, catseq = NULL,
                                     level = 0.95, method.args = list(),
                                     na.rm = FALSE, flipped_aes = NA) {
  data <- flip_data(data, flipped_aes)
  if (vctrs::vec_unique_count(data$x) < 2) {
    # Not enough data to perform fit
    return(data_frame0())
  }

  if (is.null(method)) {
    stop("`method` must be a model-fitting function or its name.", call. = FALSE)
  }
  if (is.null(data$cat)) {
    stop("`data` must contain a `cat` column.", call. = FALSE)
  }

  if (is.null(data$weight)) data$weight <- 1

  # Sort the categories once and derive BOTH grids from that same order.
  # Building the labels from a sorted vector but the x grids in order of first
  # appearance paired each category's x range with the wrong label whenever
  # the data were not already ordered by category. sort() also drops NA.
  the_cats <- sort(unique(data$cat))

  if (is.null(catseq)) {
    catseq <- rep(the_cats, each = n)
  }

  if (is.null(xseq)) {
    # One evenly spaced grid per category, spanning that category's observed
    # x range.
    grids <- lapply(seq_along(the_cats), function(i) {
      x_range <- range(data$x[data$cat %in% the_cats[i]], na.rm = TRUE)
      seq(x_range[1], x_range[2], length.out = n)
    })
    xseq <- unlist(grids, use.names = FALSE)
  }

  # # which values of x should we predict for ?
  # if (is.null(xseq)) {
  #   if (is.integer(data$x)) {
  #     if (fullrange) {
  #       xseq <- scales$x$dimension()
  #     } else {
  #       xseq <- sort(unique0(data$x))
  #     }
  #   } else {
  #     if (fullrange) {
  #       range <- scales$x$dimension()
  #     } else {
  #       range <- range(data$x, na.rm = TRUE)
  #     }
  #     xseq <- seq(range[1], range[2], length.out = n)
  #   }
  # }

  # Special case span because it's the most commonly used model argument
  if (identical(method, "loess")) {
    method.args$span <- span
  }

  if (is.character(method)) {
    if (identical(method, "gam")) {
      method <- gam_method()
    } else {
      method <- match.fun(method)
    }
  }

  # If gam and gam's method is not specified by the user then use REML
  if (identical(method, gam_method()) && is.null(method.args$method)) {
    method.args$method <- "REML"
  }

  # `weight` is left unevaluated here so the fitting function resolves it as
  # a column of `data`.
  model <- rlang::inject(method(
    formula,
    data = data,
    weights = weight,
    !!!method.args
  ))

  prediction <- predictdf(model, xseq, catseq, se, level)
  prediction$flipped_aes <- flipped_aes
  flip_data(prediction, flipped_aes)
}
library(dplyr)
# Try the compute function directly on the first five rows of mtcars, with
# wt/mpg renamed to x/y and `am` copied into `cat`.
mtcars %>%
  slice(1:5) %>%
  rename(x = wt, y = mpg) %>%
  mutate(cat = am) %>%
  compute_cat_model_smooth(
    method = lm,
    formula = y ~ x + cat,
    n = 7
  )
## Warning: `data_frame()` was deprecated in tibble 1.1.0.
## ℹ Please use `tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## x cat y ymin ymax se flipped_aes
## 1 2.3200 0 24.63012 13.717322 35.54292 2.536295 NA
## 2 2.4125 0 24.20961 14.213523 34.20570 2.323239 NA
## 3 2.5050 0 23.78910 14.700712 32.87750 2.112277 NA
## 4 2.5975 0 23.36860 15.175891 31.56130 1.904106 NA
## 5 2.6900 0 22.94809 15.634648 30.26153 1.699752 NA
## 6 2.7825 0 22.52758 16.070270 28.98490 1.500775 NA
## 7 2.8750 0 22.10708 16.472202 27.74195 1.309628 NA
## 8 3.2150 1 18.82693 12.003340 25.65051 1.585902 NA
## 9 3.2525 1 18.65645 11.470073 25.84283 1.670220 NA
## 10 3.2900 1 18.48597 10.934163 26.03778 1.755152 NA
## 11 3.3275 1 18.31550 10.395977 26.23502 1.840614 NA
## 12 3.3650 1 18.14502 9.855817 26.43423 1.926534 NA
## 13 3.4025 1 17.97455 9.313935 26.63516 2.012854 NA
## 14 3.4400 1 17.80407 8.770545 26.83760 2.099525 NA
# Overlay the computed predictions (blue) on the raw points for the first 15
# rows of mtcars. The second layer receives the plot data and runs it through
# `compute_cat_model_smooth()`.
mtcars %>%
  slice(1:15) %>%
  rename(x = wt, y = mpg) %>%
  mutate(cat = am) %>%
  ggplot() +
  aes(x, y) +
  geom_point(aes(color = factor(cat))) +
  geom_point(
    data = function(d) {
      compute_cat_model_smooth(d, method = lm, formula = y ~ x + cat, n = 10)
    },
    color = "blue"
  )
#' @rdname ggplot2-ggproto
#' @format NULL
#' @usage NULL
#' @export
StatSmoothcat <- ggplot2::ggproto("StatSmoothcat", ggplot2::StatSmooth,
  # `setup_params` and `compute_group` are inherited from StatSmooth instead of
  # being copied in as `StatSmooth$setup_params` / `StatSmooth$compute_group`.
  # Copying stores ggproto's method wrapper, whose only formal argument is
  # `...`; ggplot2 then cannot see the stat's parameter names, reports
  # `method`, `formula`, `se`, `n`, ... as unknown parameters and drops them,
  # and the computation fails with 'could not find function "method"'.
  #
  # TODO: override `compute_group` with `compute_cat_model_smooth` once the
  # category-aware computation is ready to be wired in.
  extra_params = c("na.rm", "orientation"),
  dropped_aes = c("weight"),
  required_aes = c("x", "y")
)
# Layer constructor for `StatSmoothcat`.
#
# `mapping`, `data`, `geom`, `position`, `show.legend` and `inherit.aes` are
# handed straight to `layer()`. Every other named argument, together with
# anything in `...`, is collected into the layer's `params` list.
stat_fit <- function(mapping = NULL, data = NULL,
                     geom = "point", position = "identity",
                     ...,
                     method = NULL,
                     formula = NULL,
                     se = TRUE,
                     n = 80,
                     span = 0.75,
                     fullrange = FALSE,
                     level = 0.95,
                     method.args = list(),
                     na.rm = FALSE,
                     orientation = NA,
                     show.legend = NA,
                     inherit.aes = TRUE) {
  # Parameters for the stat/geom; `...` goes last so extra settings such as
  # `color` are forwarded as well.
  layer_params <- rlang::list2(
    method = method,
    formula = formula,
    se = se,
    n = n,
    fullrange = fullrange,
    level = level,
    na.rm = na.rm,
    orientation = orientation,
    method.args = method.args,
    span = span,
    ...
  )

  layer(
    data = data,
    mapping = mapping,
    stat = StatSmoothcat,
    geom = geom,
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = layer_params
  )
}
# New…
# Try the new layer: raw points coloured by `am`, plus `stat_fit()` output
# drawn in red with n = 5.
ggplot(mtcars, aes(x = wt, y = mpg)) +
  geom_point(aes(color = am)) +
  stat_fit(n = 5, color = "red")
## Warning in stat_fit(n = 5, color = "red"): Ignoring unknown parameters: `method`, `formula`, `se`, `n`, `fullrange`,
## `level`, `method.args`, and `span`
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
## Warning: Computation failed in `stat_smoothcat()`
## Caused by error in `method()`:
## ! could not find function "method"