library(R6)


# Cross `num_trials` independent copies of a single-trial table.
#
# trial       A data frame describing one trial (for example the output of
#             bernoulli_trial()). The default `prize_wheel` is not defined in
#             this part of the file; presumably it is created elsewhere.
# num_trials  Number of copies to cross. Values of 1 or less return just the
#             first (renamed) copy.
# Returns a data frame with one row per combination of rows across the copies,
# rows of earlier trials varying slowest. The columns of copy i are prefixed
# "t<i>_" (t1_outcome, t1_prob, t2_outcome, ...).
cross_trials <- function(trial = prize_wheel, num_trials = 2) {
  # Copy of `trial` with its columns prefixed for trial number `i`.
  label_trial <- function(i) {
    stats::setNames(trial, paste0("t", i, "_", names(trial)))
  }

  crossed <- label_trial(1)
  if (num_trials > 1) {
    for (i in 2:num_trials) {
      # expand_grid() rather than crossing(): crossing() de-duplicates and
      # sorts its inputs, which silently drops probability mass whenever two
      # rows of `trial` are identical (e.g. two equal wheel segments).
      crossed <- tidyr::expand_grid(crossed, label_trial(i))
    }
  }
  crossed
}


# A single Bernoulli trial as a two-row table: outcome 0 and outcome 1 with
# their probabilities.
#
# prob  Probability of outcome 1; must be one number in [0, 1].
# Returns a tibble with integer column `outcome` (0, 1) and numeric column
# `prob` (1 - prob, prob).
bernoulli_trial <- function(prob = .25) {
  # Reject values that would yield negative "probabilities" (e.g. prob = 1.5
  # gives c(-0.5, 1.5)) instead of silently returning an invalid table.
  if (!is.numeric(prob) || length(prob) != 1 || is.na(prob) ||
      prob < 0 || prob > 1) {
    stop("`prob` must be a single number between 0 and 1.", call. = FALSE)
  }
  tibble::tibble(outcome = 0:1, prob = c(1 - prob, prob))
}

library(magrittr)
# Example: cross three independent Bernoulli(0.25) trials. The result has one
# row per combination of outcomes, with a t<i>_outcome / t<i>_prob column pair
# for each trial i.
bernoulli_trial() %>% 
  cross_trials(num_trials = 3)
## # A tibble: 8 × 6
##   t1_outcome t1_prob t2_outcome t2_prob t3_outcome t3_prob
##        <int>   <dbl>      <int>   <dbl>      <int>   <dbl>
## 1          0    0.75          0    0.75          0    0.75
## 2          0    0.75          0    0.75          1    0.25
## 3          0    0.75          1    0.25          0    0.75
## 4          0    0.75          1    0.25          1    0.25
## 5          1    0.25          0    0.75          0    0.75
## 6          1    0.25          0    0.75          1    0.25
## 7          1    0.25          1    0.25          0    0.75
## 8          1    0.25          1    0.25          1    0.25
# Trials: a mutable (R6) container for repeated independent trials.
#
# Fields
#   trial  The single-trial table passed to init().
#   index  How many copies of `trial` are currently crossed together.
#   out    The crossed table, rebuilt with cross_trials() whenever `trial` or
#          `index` changes.
#
# Usage: Trials$new() leaves every field NULL. Call init(trial = ...) to set
# the trial, then update() to add more trials. Trials$new(trial = ...) does
# the first two steps at once.
Trials <- R6Class("Trials",
  public = list(
    trial = NULL,
    index = NULL,
    out = NULL,

    # Optional constructor argument. With no `trial` this does nothing, so
    # Trials$new() keeps its original behaviour (all fields NULL).
    initialize = function(trial = NULL) {
      if (!is.null(trial)) {
        self$init(trial = trial)
      }
    },

    # Store `trial`, reset to a single trial and rebuild `out`.
    # Returns the object invisibly so calls can be chained.
    init = function(trial = NULL) {
      if (is.null(trial)) {
        stop("`trial` must be supplied to init().", call. = FALSE)
      }
      self$trial <- trial
      self$index <- 1
      private$rebuild()
    },

    # Add `increment` trials (default 1) and rebuild `out`.
    # Returns the object invisibly so calls can be chained.
    update = function(increment = 1) {
      if (is.null(self$trial)) {
        stop("No trial set; call init(trial = ...) first.", call. = FALSE)
      }
      new_index <- self$index + increment
      # Refuse to go below one trial: cross_trials() would still return one
      # trial, leaving `index` and `out` out of sync.
      if (new_index < 1) {
        stop("The number of trials cannot drop below 1.", call. = FALSE)
      }
      self$index <- new_index
      private$rebuild()
    },

    # Print the crossed table (NULL before init()). Returns the object
    # invisibly, as R6 print methods conventionally do.
    print = function(...) {
      print(self$out, ...)
      invisible(self)
    }
  ),
  private = list(
    # Recompute `out` from the current `trial` and `index`.
    rebuild = function() {
      self$out <- cross_trials(self$trial, num_trials = self$index)
      invisible(self)
    }
  )
)
# Example: step-by-step use of a Trials object. new() does not call init(), so
# every field is still NULL and printing the object shows NULL.
my_trials <- Trials$new() 

my_trials
## NULL
# init() stores the single trial and builds the one-trial table in `out`.
my_trials$init(trial = bernoulli_trial())
my_trials$out
## # A tibble: 2 × 2
##   t1_outcome t1_prob
##        <int>   <dbl>
## 1          0    0.75
## 2          1    0.25
# update() (default increment of 1) adds a second trial and rebuilds `out`.
my_trials$update()
my_trials$out
## # A tibble: 4 × 4
##   t1_outcome t1_prob t2_outcome t2_prob
##        <int>   <dbl>      <int>   <dbl>
## 1          0    0.75          0    0.75
## 2          0    0.75          1    0.25
## 3          1    0.25          0    0.75
## 4          1    0.25          1    0.25
# Pipe-friendly wrapper: build a Trials object and init() it in one call.
#
# trial  The single-trial table to start from (e.g. bernoulli_trial()).
# Returns the new Trials object, holding one trial. It is returned visibly,
# so a pipeline ending in trial_init() prints the one-trial table.
trial_init <- function(trial = NULL) {
  trials <- Trials$new()
  trials$init(trial = trial)
  trials
}

# Example: trial_init() returns the Trials object, and printing it shows the
# one-trial table.
bernoulli_trial() %>% trial_init()
## # A tibble: 2 × 2
##   t1_outcome t1_prob
##        <int>   <dbl>
## 1          0    0.75
## 2          1    0.25
# Pipe-friendly wrapper around Trials$update().
#
# trials     A Trials object (e.g. from trial_init()).
# increment  Number of trials to add.
# Returns `trials` visibly so calls can be chained. Because R6 objects have
# reference semantics, the object passed in is modified in place; no copy is
# made.
trial_advance <- function(trials, increment = 1) {
  trials$update(increment = increment)
  trials
}

# Pipe-friendly way to add trials to either a raw trial table or an existing
# Trials object.
#
# trials     Either a single-trial table (e.g. bernoulli_trial()) or a Trials
#            R6 object.
# increment  For a raw table: the total number of trials to set up (so the
#            default of 1 gives a single trial). For a Trials object: how many
#            trials to add to it.
# Returns the Trials object. An existing Trials object is updated in place
# (R6 reference semantics) and returned.
add_trials <- function(trials, increment = 1) {
  if (R6::is.R6(trials)) {
    trial_advance(trials = trials, increment = increment)
  } else {
    # trial_init() already creates one trial, so advance by one fewer.
    if (increment < 1) {
      stop("`increment` must be at least 1 when starting from a trial table.",
           call. = FALSE)
    }
    trial_advance(trials = trial_init(trial = trials),
                  increment = increment - 1)
  }
}


# Example: chain the wrappers. Starting from one trial, advance by 1 and then
# by 2, giving 4 trials (2^4 = 16 rows).
bernoulli_trial() %>%
  trial_init() %>% 
  trial_advance() %>% 
  trial_advance(2)
## # A tibble: 16 × 8
##    t1_outcome t1_prob t2_outcome t2_prob t3_outcome t3_prob t4_outcome t4_prob
##         <int>   <dbl>      <int>   <dbl>      <int>   <dbl>      <int>   <dbl>
##  1          0    0.75          0    0.75          0    0.75          0    0.75
##  2          0    0.75          0    0.75          0    0.75          1    0.25
##  3          0    0.75          0    0.75          1    0.25          0    0.75
##  4          0    0.75          0    0.75          1    0.25          1    0.25
##  5          0    0.75          1    0.25          0    0.75          0    0.75
##  6          0    0.75          1    0.25          0    0.75          1    0.25
##  7          0    0.75          1    0.25          1    0.25          0    0.75
##  8          0    0.75          1    0.25          1    0.25          1    0.25
##  9          1    0.25          0    0.75          0    0.75          0    0.75
## 10          1    0.25          0    0.75          0    0.75          1    0.25
## 11          1    0.25          0    0.75          1    0.25          0    0.75
## 12          1    0.25          0    0.75          1    0.25          1    0.25
## 13          1    0.25          1    0.25          0    0.75          0    0.75
## 14          1    0.25          1    0.25          0    0.75          1    0.25
## 15          1    0.25          1    0.25          1    0.25          0    0.75
## 16          1    0.25          1    0.25          1    0.25          1    0.25
# Add a `global_outcome` column holding, for each row, the sum of every
# column whose name contains `var_key` (e.g. t1_outcome + t2_outcome + ...).
#
# data     A data frame, typically the `out` table of a Trials object.
# var_key  Substring identifying the columns to add up (matched with
#          dplyr::contains(), which ignores case by default).
# Returns `data` with a numeric `global_outcome` column added or replaced.
sum_across <- function(data, var_key = "outcome") {
  dplyr::mutate(
    .data = data,
    global_outcome = rowSums(dplyr::across(
      # Exclude the output column itself: its name also contains "outcome",
      # so a previous call (or seq_across()) would otherwise be summed in.
      dplyr::contains(var_key) & !dplyr::any_of("global_outcome")
    ))
  )
}

# Add a `global_outcome` column labelling each row with its sequence of
# per-trial values, e.g. "0, 1, 1".
#
# data     A data frame, typically the `out` table of a Trials object.
# var_key  Regular expression (stringr syntax) selecting the columns to join,
#          in their left-to-right order.
# Returns `data` with a character `global_outcome` column added or replaced.
seq_across <- function(data, var_key = "outcome") {
  col_names <- names(data)
  # Skip the output column: its own name matches "outcome", so a previous
  # call (or sum_across()) would otherwise be pasted into the label.
  cols <- setdiff(col_names[stringr::str_detect(col_names, var_key)],
                  "global_outcome")

  # Convert column by column rather than apply() on data[, cols]: apply()
  # coerces the data frame to a matrix (padding numbers via format() when
  # column types are mixed), and data[, cols] drops to a plain vector when a
  # single column of a base data.frame matches.
  pieces <- lapply(cols, function(nm) as.character(data[[nm]]))
  data$global_outcome <- if (length(pieces) == 0) {
    rep("", nrow(data))
  } else {
    do.call(paste, c(pieces, sep = ", "))
  }
  data
}


# Add a `global_probs` column holding, for each row, the product of every
# column whose name matches `var_key`, i.e. the joint probability of that
# row's outcomes when the trials are independent.
#
# data     A data frame, typically the `out` table of a Trials object.
# var_key  Regular expression (stringr syntax) selecting the columns to
#          multiply.
# Returns `data` with a numeric `global_probs` column added or replaced.
prod_across <- function(data, var_key = "prob") {
  col_names <- names(data)
  # Skip the output column: "global_probs" matches "prob", so a previous call
  # would otherwise be multiplied in again.
  cols <- setdiff(col_names[stringr::str_detect(col_names, var_key)],
                  "global_probs")

  # Element-wise product of the selected columns. Starting from a vector of
  # ones keeps the result one value per row (the empty product is 1) and
  # avoids apply(), which fails when data[, cols] drops to a vector.
  data$global_probs <- Reduce(`*`, lapply(cols, function(nm) data[[nm]]),
                              rep(1, nrow(data)))
  data
}

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Example: seven fair (p = 0.5) trials. The first add_trials() starts from the
# raw table with one trial, the second adds one and add_trials(5) adds five,
# for 1 + 1 + 5 = 7 trials. `.$out` pulls the crossed table out of the R6
# object. sum_across() counts the 1-outcomes in each row, prod_across() gives
# each row's joint probability, and the grouped sum is the distribution of
# the count.
bernoulli_trial(prob = .5) %>% 
  add_trials() %>% 
  add_trials() %>% 
  add_trials(5) %>% 
  .$out %>% 
  sum_across() %>% 
  prod_across() %>% 
  group_by(global_outcome) %>% 
  summarize(probs = sum(global_probs))
## # A tibble: 8 × 2
##   global_outcome   probs
##            <dbl>   <dbl>
## 1              0 0.00781
## 2              1 0.0547 
## 3              2 0.164  
## 4              3 0.273  
## 5              4 0.273  
## 6              5 0.164  
## 7              6 0.0547 
## 8              7 0.00781
# Example: three fair trials set up in one call. seq_across() labels each row
# with its ordered sequence of outcomes, so every one of the 2^3 sequences
# appears with its own probability.
bernoulli_trial(prob = .5) %>% 
  add_trials(3) %>% 
  .$out %>% 
  seq_across() %>% 
  prod_across() %>% 
  group_by(global_outcome) %>% 
  summarize(probs = sum(global_probs))
## # A tibble: 8 × 2
##   global_outcome probs
##   <chr>          <dbl>
## 1 0, 0, 0        0.125
## 2 0, 0, 1        0.125
## 3 0, 1, 0        0.125
## 4 0, 1, 1        0.125
## 5 1, 0, 0        0.125
## 6 1, 0, 1        0.125
## 7 1, 1, 0        0.125
## 8 1, 1, 1        0.125