
# Build the joint outcome table for `num_trials` independent repeats of a
# single trial.
#
# trial:      a data frame describing one trial (e.g. columns `outcome`,
#             `prob`); every column is copied once per repeat.
# num_trials: number of repeats. Values of 1 or less return only the first,
#             renamed copy.
#
# Returns a data frame whose columns are the columns of `trial` prefixed with
# "t1_", "t2_", ..., combined across repeats with tidyr::crossing().
cross_trials <- function(trial = prize_wheel, num_trials = 2){

  df <- trial
  names(df) <- paste0("t", 1, "_", names(trial))

  # The guard keeps `2:num_trials` from counting down when num_trials < 2.
  if (num_trials > 1) {
    for (i in 2:num_trials) {
      temp <- trial
      names(temp) <- paste0("t", i, "_", names(trial))
      df <- tidyr::crossing(df, temp)
    }
  }

  df
}

# One Bernoulli trial as an outcome/probability table.
#
# prob: probability of the outcome 1; must be a single number in [0, 1].
#
# Returns a tibble with outcome = 0:1 and prob = c(1 - prob, prob).
bernoulli_trial <- function(prob = .25){
  # Reject values that would produce negative or >1 "probabilities".
  stopifnot(is.numeric(prob), length(prob) == 1, !is.na(prob),
            prob >= 0, prob <= 1)

  tibble::tibble(outcome = 0:1, prob = c(1 - prob, prob))
}

# Sample space of three repeats of the default Bernoulli trial.
cross_trials(bernoulli_trial(), num_trials = 3)
## # A tibble: 8 × 6
##   t1_outcome t1_prob t2_outcome t2_prob t3_outcome t3_prob
##        <int>   <dbl>      <int>   <dbl>      <int>   <dbl>
## 1          0    0.75          0    0.75          0    0.75
## 2          0    0.75          0    0.75          1    0.25
## 3          0    0.75          1    0.25          0    0.75
## 4          0    0.75          1    0.25          1    0.25
## 5          1    0.25          0    0.75          0    0.75
## 6          1    0.25          0    0.75          1    0.25
## 7          1    0.25          1    0.25          0    0.75
## 8          1    0.25          1    0.25          1    0.25
# R6 container for a growing set of repeated trials.
#
# Fields:
#   trial - the single-trial table passed to init()
#   index - number of trials currently crossed together
#   out   - cross_trials(trial, num_trials = index)
#
# Methods return the object invisibly so calls can be chained or piped.
Trials <- R6Class("Trials",
  public = list(
    # objects
    trial = NULL,
    index = NULL,
    out = NULL,

    # Store the single trial and build its one-trial table.
    init = function(trial = NULL){
      self$trial <- trial
      self$index <- 1
      self$out <- cross_trials(self$trial, num_trials = self$index)
      invisible(self)
    },

    # Add `increment` more trials and rebuild `out` from scratch.
    update = function(increment = 1){
      self$index <- self$index + increment
      self$out <- cross_trials(self$trial, num_trials = self$index)
      invisible(self)
    },

    # Print only the crossed table instead of every field.
    print = function(...) {
      print(self$out)
      invisible(self)
    }
  )
)

# Create an empty Trials object, then initialise it with one Bernoulli trial;
# init() stores the trial, sets index to 1 and fills `out`.
my_trials <- Trials$new() 

my_trials$init(trial = bernoulli_trial())
## # A tibble: 2 × 2
##   t1_outcome t1_prob
##        <int>   <dbl>
## 1          0    0.75
## 2          1    0.25
## # A tibble: 4 × 4
##   t1_outcome t1_prob t2_outcome t2_prob
##        <int>   <dbl>      <int>   <dbl>
## 1          0    0.75          0    0.75
## 2          0    0.75          1    0.25
## 3          1    0.25          0    0.75
## 4          1    0.25          1    0.25
# Wrap the R6 methods in plain functions so they can be used in pipes
# Create a Trials object and initialise it with `trial` (pipe-friendly
# wrapper around Trials$new() + $init()).
#
# trial: single-trial table passed through to cross_trials().
#
# Returns the initialised Trials object (invisibly, as $init() does).
trial_init <- function(trial = NULL){
  my_trials <- Trials$new()
  my_trials$init(trial = trial)
}

# Same initialisation as above, done through the pipe-friendly wrapper.
bernoulli_trial() %>% trial_init()
## # A tibble: 2 × 2
##   t1_outcome t1_prob
##        <int>   <dbl>
## 1          0    0.75
## 2          1    0.25
# Advance an initialised Trials object by `increment` trials.
#
# trials:    a Trials R6 object (for example from trial_init()).
# increment: number of trials to add.
#
# Returns the same object invisibly. R6 objects have reference semantics, so
# the object passed in is updated in place as well; no copy is made.
trial_advance <- function(trials, increment = 1){
  trials$update(increment = increment)
}

# Add trials to either a raw single-trial table or an existing Trials object.
#
# trials:    a single-trial table (it is wrapped with trial_init() first, which
#            already counts as one trial) or a Trials R6 object.
# increment: number of trials to add in total by this call.
#
# Returns the Trials object (invisibly).
add_trials <- function(trials, increment = 1){
  if (is.R6(trials)) {
    return(trial_advance(trials = trials, increment = increment))
  }

  # trial_init() already contributes one trial, so advance by one fewer.
  my_trials <- trial_init(trial = trials)
  trial_advance(trials = my_trials, increment = increment - 1)
}

# Initialise with one Bernoulli trial, advance to four trials in total
# (1 + 1 + 2) and extract the crossed table shown below.
bernoulli_trial() %>%
  trial_init() %>%
  trial_advance() %>%
  trial_advance(increment = 2) %>%
  .$out
## # A tibble: 16 × 8
##    t1_outcome t1_prob t2_outcome t2_prob t3_outcome t3_prob t4_outcome t4_prob
##         <int>   <dbl>      <int>   <dbl>      <int>   <dbl>      <int>   <dbl>
##  1          0    0.75          0    0.75          0    0.75          0    0.75
##  2          0    0.75          0    0.75          0    0.75          1    0.25
##  3          0    0.75          0    0.75          1    0.25          0    0.75
##  4          0    0.75          0    0.75          1    0.25          1    0.25
##  5          0    0.75          1    0.25          0    0.75          0    0.75
##  6          0    0.75          1    0.25          0    0.75          1    0.25
##  7          0    0.75          1    0.25          1    0.25          0    0.75
##  8          0    0.75          1    0.25          1    0.25          1    0.25
##  9          1    0.25          0    0.75          0    0.75          0    0.75
## 10          1    0.25          0    0.75          0    0.75          1    0.25
## 11          1    0.25          0    0.75          1    0.25          0    0.75
## 12          1    0.25          0    0.75          1    0.25          1    0.25
## 13          1    0.25          1    0.25          0    0.75          0    0.75
## 14          1    0.25          1    0.25          0    0.75          1    0.25
## 15          1    0.25          1    0.25          1    0.25          0    0.75
## 16          1    0.25          1    0.25          1    0.25          1    0.25
# Add a `global_outcome` column: the row-wise sum of every column whose name
# matches `var_key` (e.g. t1_outcome, t2_outcome, ...).
#
# data:    data frame or tibble, e.g. the `out` table of a Trials object.
# var_key: regular expression used to select the columns to sum.
#
# Returns `data` with the new column; errors if no column matches.
sum_across <- function(data, var_key = "outcome"){
  col_list <- names(data)[grepl(var_key, names(data))]
  if (length(col_list) == 0) {
    stop("no columns of `data` match var_key = '", var_key, "'", call. = FALSE)
  }

  # `data[col_list]` never drops to a vector, even for a single column.
  data$global_outcome <- rowSums(data[col_list])
  data
}

# Add a `global_outcome` column: the values of every column whose name matches
# `var_key`, pasted together row by row with ", " (e.g. "0, 1, 1").
#
# data:    data frame or tibble, e.g. the `out` table of a Trials object.
# var_key: regular expression used to select the columns to paste.
#
# Returns `data` with the new column; errors if no column matches.
seq_across <- function(data, var_key = "outcome"){
  col_list <- names(data)[grepl(var_key, names(data))]
  if (length(col_list) == 0) {
    stop("no columns of `data` match var_key = '", var_key, "'", call. = FALSE)
  }

  # Paste column-wise instead of apply(): no matrix coercion, no dimension
  # drop for a single column. unname() keeps column names from being
  # matched against paste()'s own arguments (sep/collapse).
  cols <- unname(as.list(data[col_list]))
  data$global_outcome <- do.call(paste, c(cols, sep = ", "))
  data
}

# Add a `global_probs` column: the row-wise product of every column whose name
# matches `var_key` (e.g. t1_prob, t2_prob, ...).
#
# data:    data frame or tibble, e.g. the `out` table of a Trials object.
# var_key: regular expression used to select the columns to multiply.
#
# Returns `data` with the new column; errors if no column matches.
prod_across <- function(data, var_key = "prob"){
  col_list <- names(data)[grepl(var_key, names(data))]
  if (length(col_list) == 0) {
    stop("no columns of `data` match var_key = '", var_key, "'", call. = FALSE)
  }

  # Multiply the columns element-wise; the double init value 1 keeps the
  # result double, as prod() would.
  data$global_probs <- Reduce(`*`, as.list(data[col_list]), 1)
  data
}

# Number of successes across 1 + 1 + 5 = 7 fair Bernoulli trials: the first
# add_trials() wraps the raw table (1 trial), the next adds 1, then 5 more.
# sum_across() adds global_outcome, prod_across() multiplies the per-trial
# probabilities into global_probs, and the probabilities are totalled per
# global_outcome value.
bernoulli_trial(prob = .5) %>% 
  add_trials() %>% 
  add_trials() %>% 
  add_trials(5) %>% 
  .$out %>% 
  sum_across() %>% 
  prod_across() %>% 
  group_by(global_outcome) %>% 
  summarize(probs = sum(global_probs))
## # A tibble: 8 × 2
##   global_outcome   probs
##            <dbl>   <dbl>
## 1              0 0.00781
## 2              1 0.0547 
## 3              2 0.164  
## 4              3 0.273  
## 5              4 0.273  
## 6              5 0.164  
## 7              6 0.0547 
## 8              7 0.00781
# Probability of each ordered outcome sequence for 3 fair Bernoulli trials:
# seq_across() pastes the per-trial outcomes into one string per row (e.g.
# "0, 1, 1"), and prod_across() multiplies the per-trial probabilities.
bernoulli_trial(prob = .5) %>% 
  add_trials(3) %>% 
  .$out %>% 
  seq_across() %>% 
  prod_across() %>% 
  group_by(global_outcome) %>% 
  summarize(probs = sum(global_probs))
## # A tibble: 8 × 2
##   global_outcome probs
##   <chr>          <dbl>
## 1 0, 0, 0        0.125
## 2 0, 0, 1        0.125
## 3 0, 1, 0        0.125
## 4 0, 1, 1        0.125
## 5 1, 0, 0        0.125
## 6 1, 0, 1        0.125
## 7 1, 1, 0        0.125
## 8 1, 1, 1        0.125