# Legend key for the rounded-rectangle treemap geom.
#
# Draws a single rounded rectangle covering 90% of the key area.
#
# Arguments:
#   data   - key data; `fill`, `colour`, `alpha` and `linetype` are read.
#   params - layer parameters; `radius` and `color` are read when present.
#   size   - not used by this function.
#
# Returns a grid roundrect grob named "lkey".
draw_key_rrect <- function(data, params, size) {
  # Local NULL-coalescing helper, so the key does not depend on a `%||%`
  # operator being available on the search path.
  or_else <- function(x, fallback) if (is.null(x)) fallback else x

  # Corner radius is capped at 3pt; a missing radius falls back to the cap.
  max_radius <- grid::unit(3, "pt")
  key_radius <- min(or_else(params$radius, max_radius), max_radius)

  # Fill falls back to the outline colour, then to a fixed dark grey.
  key_fill <- or_else(or_else(data$fill, data$colour), "grey20")

  grid::roundrectGrob(
    r = key_radius,
    default.units = "native",
    width = 0.9,
    height = 0.9,
    name = "lkey",
    gp = grid::gpar(
      col = or_else(params$color, "white"),
      fill = ggplot2::alpha(key_fill, data$alpha),
      lty = or_else(data$linetype, 1)
    )
  )
}


# Default aesthetics used by GeomTreemap: grey outline, dark grey fill,
# size 0.5, solid line type, fully opaque.
my_default_aes <- ggplot2::aes(
  colour = "grey", fill = "grey35", size = 0.5, linetype = 1, alpha = 1
)


# Panel drawing function for the rounded-rectangle treemap geom.
#
# Transforms the layer data with the coord, computes the tile layout with
# treemapify::treemapify(), and draws one rounded rectangle per tile.
#
# Arguments:
#   data         - layer data; the layout is computed from its `area` column.
#                  `subgroup`, `subgroup2` and `subgroup3` columns are passed
#                  on to the layout when present.
#   panel_scales - passed to `coord$transform()` together with `data`.
#   coord        - coord object providing `transform()`.
#   fixed, layout, start - forwarded unchanged to treemapify().
#   radius       - corner radius used as `r` for every tile.
#
# Returns a grob tree whose name carries the "geom_treemap" prefix.
my_draw_panel <- function(
    data,
    panel_scales,
    coord,
    fixed = NULL,
    layout = "squarified",
    start = "bottomleft",
    radius = grid::unit(3, "pt")
) {
  data <- coord$transform(data, panel_scales)

  # Arguments for the treemap layout; column names are passed as strings.
  layout_args <- list(
    data = data,
    area = "area",
    fixed = fixed,
    layout = layout,
    start = start
  )
  subgroup_cols <- intersect(
    c("subgroup", "subgroup2", "subgroup3"),
    names(data)
  )
  for (level in subgroup_cols) {
    layout_args[[level]] <- level
  }

  tiles <- do.call(treemapify::treemapify, layout_args)

  # One roundrect grob per tile, anchored at the tile's top-left corner.
  tile_grobs <- lapply(seq_along(tiles$xmin), function(i) {
    grid::roundrectGrob(
      x = tiles$xmin[i],
      width = tiles$xmax[i] - tiles$xmin[i],
      y = tiles$ymax[i],
      height = tiles$ymax[i] - tiles$ymin[i],
      default.units = "native",
      r = radius,
      just = c("left", "top"),
      gp = grid::gpar(
        col = tiles$colour[i],
        fill = ggplot2::alpha(tiles$fill[i], tiles$alpha[i]),
        lwd = tiles$size[i],
        lty = tiles$linetype[i]
      )
    )
  })

  tree <- grid::grobTree(children = do.call(grid::gList, tile_grobs))

  # Name the tree with the exported grid helper instead of reaching into
  # ggplot2's unexported ggname() via `:::`.
  tree$name <- grid::grobName(tree, "geom_treemap")
  tree
}


# ggproto class for the rounded-corner treemap geom. It extends ggplot2::Geom,
# requires the `area` aesthetic, and plugs in the default aesthetics, legend
# key function and panel drawing function defined in this file.
GeomTreemap <- ggplot2::ggproto(
  "GeomTreemap",
  ggplot2::Geom,
  required_aes = "area",
  default_aes = my_default_aes,
  draw_key = draw_key_rrect,
  draw_panel = my_draw_panel
)



# Layer constructor for the rounded-corner treemap geom.
#
# Arguments:
#   mapping, data, stat, position, show.legend, inherit.aes
#           - handed to ggplot2::layer() unchanged.
#   na.rm, fixed, layout, start, radius, ...
#           - collected into the layer's `params` list.
#   radius  - defaults to 0pt.
#
# Returns the object created by ggplot2::layer() with GeomTreemap as geom.
geom_treemap <- function(
  mapping = NULL,
  data = NULL,
  stat = "identity",
  position = "identity",
  na.rm = FALSE,
  show.legend = NA,
  inherit.aes = TRUE,
  fixed = NULL,
  layout = "squarified",
  start = "bottomleft",
  radius = grid::unit(0, "pt"),
  ...
) {
  # Gather everything that is passed on as layer parameters.
  layer_params <- list(
    na.rm = na.rm,
    fixed = fixed,
    layout = layout,
    start = start,
    radius = radius,
    ...
  )

  ggplot2::layer(
    geom = GeomTreemap,
    stat = stat,
    data = data,
    mapping = mapping,
    position = position,
    show.legend = show.legend,
    inherit.aes = inherit.aes,
    params = layer_params
  )
}

library(ggplot2)
## Warning: package 'ggplot2' was built under R version 3.6.2
library(treemapify)
## 
## Attaching package: 'treemapify'
## The following object is masked _by_ '.GlobalEnv':
## 
##     geom_treemap
# Example: treemap of `cars` with tile area and fill both mapped to `speed`.
ggplot(data = cars, mapping = aes(area = speed, fill = speed)) +
  geom_treemap()