library(R6)
# Build the joint outcome table for `num_trials` independent repetitions of a
# single trial.
#
# Each repetition's columns are prefixed "t<i>_" (t1_outcome, t1_prob, ...),
# and every row of one repetition is paired with every row of the others, so
# the result has nrow(trial)^num_trials rows. Earlier trials vary slowest.
#
# trial:      a data frame with one row per possible result of one trial
#             (e.g. bernoulli_trial()). The default `prize_wheel` is not
#             defined in this file; presumably it exists elsewhere — verify.
# num_trials: number of repetitions, a whole number >= 1.
#
# tidyr::expand_grid() is used instead of tidyr::crossing(): crossing()
# de-duplicates and sorts its inputs, which silently drops repeated rows
# (e.g. two wheel slices with the same outcome and probability) and makes the
# joint probabilities stop summing to 1.
cross_trials <- function(trial = prize_wheel, num_trials = 2) {
  stopifnot(
    is.data.frame(trial),
    is.numeric(num_trials), length(num_trials) == 1, !is.na(num_trials),
    num_trials >= 1, num_trials == round(num_trials)
  )
  # A copy of `trial` with its columns tagged for repetition `i`.
  label_trial <- function(i) {
    stats::setNames(trial, paste0("t", i, "_", names(trial)))
  }
  # seq_len(1)[-1] is empty, so a single trial is returned unchanged.
  Reduce(
    function(acc, i) tidyr::expand_grid(acc, label_trial(i)),
    seq_len(num_trials)[-1],
    label_trial(1)
  )
}
# One Bernoulli trial as a probability table: outcome 0 with probability
# 1 - prob and outcome 1 with probability prob.
#
# prob: success probability, a single number in [0, 1]. It is validated so
#       that an out-of-range value fails loudly instead of producing a
#       negative "probability" row.
# Returns a two-row tibble with columns `outcome` (integer 0:1) and `prob`.
bernoulli_trial <- function(prob = .25) {
  stopifnot(
    is.numeric(prob), length(prob) == 1, !is.na(prob),
    prob >= 0, prob <= 1
  )
  tibble::tibble(outcome = 0:1, prob = c(1 - prob, prob))
}
library(magrittr)
# Three independent Bernoulli(0.25) trials: all 2^3 = 8 joint outcomes.
cross_trials(bernoulli_trial(), num_trials = 3)
## # A tibble: 8 × 6
## t1_outcome t1_prob t2_outcome t2_prob t3_outcome t3_prob
## <int> <dbl> <int> <dbl> <int> <dbl>
## 1 0 0.75 0 0.75 0 0.75
## 2 0 0.75 0 0.75 1 0.25
## 3 0 0.75 1 0.25 0 0.75
## 4 0 0.75 1 0.25 1 0.25
## 5 1 0.25 0 0.75 0 0.75
## 6 1 0.25 0 0.75 1 0.25
## 7 1 0.25 1 0.25 0 0.75
## 8 1 0.25 1 0.25 1 0.25
# Trials: a mutable (R6) container holding a single-trial probability table
# and the joint table for `index` independent repetitions of it.
#
# Fields
#   trial: the single-trial table (e.g. bernoulli_trial()); NULL until init().
#   index: how many repetitions `out` currently represents.
#   out:   cross_trials(trial, num_trials = index), recomputed on every change.
#
# Methods modify the object in place (R6 reference semantics) and return it
# invisibly, so calls can be chained: x$init(t)$update()$out.
Trials <- R6Class("Trials",
  public = list(
    # objects
    trial = NULL,
    index = NULL,
    out = NULL,

    # Optional constructor. Trials$new() leaves every field NULL, as before;
    # Trials$new(trial) is shorthand for Trials$new()$init(trial).
    initialize = function(trial = NULL) {
      if (!is.null(trial)) {
        self$init(trial = trial)
      }
      invisible(self)
    },

    # (Re)set the container to exactly one repetition of `trial`.
    init = function(trial = NULL) {
      if (is.null(trial)) {
        stop("`trial` must be supplied (e.g. bernoulli_trial()).", call. = FALSE)
      }
      self$trial <- trial
      self$index <- 1
      self$out <- cross_trials(self$trial, num_trials = self$index)
      invisible(self)
    },

    # Add `increment` repetitions (may be negative) and recompute `out`.
    update = function(increment = 1) {
      if (is.null(self$trial)) {
        stop("Call init() with a trial before update().", call. = FALSE)
      }
      new_index <- self$index + increment
      if (length(new_index) != 1 || is.na(new_index) || new_index < 1) {
        stop("update() would leave fewer than one trial.", call. = FALSE)
      }
      # Compute first, then assign, so a failure leaves the object unchanged.
      new_out <- cross_trials(self$trial, num_trials = new_index)
      self$index <- new_index
      self$out <- new_out
      invisible(self)
    },

    # Print the joint table (prints NULL before init()); return the object
    # invisibly, following R's print() convention.
    print = function(...) {
      print(self$out)
      invisible(self)
    }
  )
)
# A freshly constructed Trials object has no table yet, so it prints NULL.
my_trials <- Trials$new()
my_trials
## NULL
# Seed it with one Bernoulli trial; init() returns the object, so read $out.
my_trials$init(trial = bernoulli_trial())$out
## # A tibble: 2 × 2
## t1_outcome t1_prob
## <int> <dbl>
## 1 0 0.75
## 2 1 0.25
# Add one more trial in place and look at the 2-trial joint table.
my_trials$update()$out
## # A tibble: 4 × 4
## t1_outcome t1_prob t2_outcome t2_prob
## <int> <dbl> <int> <dbl>
## 1 0 0.75 0 0.75
## 2 0 0.75 1 0.25
## 3 1 0.25 0 0.75
## 4 1 0.25 1 0.25
# Wrap the R6 methods in plain functions so they can be chained with %>%
# Create a Trials object holding one repetition of `trial`. Intended for the
# head of a pipe: bernoulli_trial() %>% trial_init().
#
# trial: single-trial table handed to Trials$init().
# Returns the new Trials object, visibly, so it auto-prints at top level.
trial_init <- function(trial = NULL) {
  # init() hands back the same object (invisibly), so chain it off new().
  container <- Trials$new()$init(trial = trial)
  container
}
# The same one-trial table, built through the pipe-friendly wrapper.
trial_init(trial = bernoulli_trial())
## # A tibble: 2 × 2
## t1_outcome t1_prob
## <int> <dbl>
## 1 0 0.75
## 2 1 0.25
# Advance a Trials object by `increment` repetitions, pipe-style.
#
# trials:    a Trials R6 object (anything with an update(increment) method
#            and R6's default clone()).
# increment: number of repetitions to add.
# Returns an advanced copy. R6 objects have reference semantics, so plain
# assignment (`x <- trials`) does not copy; without clone() the caller's
# object would be mutated as a hidden side effect.
trial_advance <- function(trials, increment = 1) {
  advanced <- trials$clone()
  advanced$update(increment = increment)
  advanced
}
# Pipe-friendly entry point that accepts either a raw single-trial table or an
# existing Trials object.
#
# trials:    a single-trial table (e.g. bernoulli_trial()) or a Trials object.
# increment: for a table, the TOTAL number of repetitions to build (>= 1);
#            for a Trials object, how many repetitions to add.
# Returns a Trials object.
add_trials <- function(trials, increment = 1) {
  if (is.R6(trials)) {
    trial_advance(trials = trials, increment = increment)
  } else {
    if (!is.numeric(increment) || length(increment) != 1 ||
        is.na(increment) || increment < 1) {
      stop("`increment` must be at least 1 when starting from a trial table.",
           call. = FALSE)
    }
    started <- trial_init(trial = trials)
    # trial_init() already supplies the first repetition.
    trial_advance(trials = started, increment = increment - 1)
  }
}
# Start with 1 trial, add 1, then add 2 more: 4 trials, 2^4 = 16 rows.
bernoulli_trial() %>%
  trial_init() %>%
  trial_advance(increment = 1) %>%
  trial_advance(increment = 2)
## # A tibble: 16 × 8
## t1_outcome t1_prob t2_outcome t2_prob t3_outcome t3_prob t4_outcome t4_prob
## <int> <dbl> <int> <dbl> <int> <dbl> <int> <dbl>
## 1 0 0.75 0 0.75 0 0.75 0 0.75
## 2 0 0.75 0 0.75 0 0.75 1 0.25
## 3 0 0.75 0 0.75 1 0.25 0 0.75
## 4 0 0.75 0 0.75 1 0.25 1 0.25
## 5 0 0.75 1 0.25 0 0.75 0 0.75
## 6 0 0.75 1 0.25 0 0.75 1 0.25
## 7 0 0.75 1 0.25 1 0.25 0 0.75
## 8 0 0.75 1 0.25 1 0.25 1 0.25
## 9 1 0.25 0 0.75 0 0.75 0 0.75
## 10 1 0.25 0 0.75 0 0.75 1 0.25
## 11 1 0.25 0 0.75 1 0.25 0 0.75
## 12 1 0.25 0 0.75 1 0.25 1 0.25
## 13 1 0.25 1 0.25 0 0.75 0 0.75
## 14 1 0.25 1 0.25 0 0.75 1 0.25
## 15 1 0.25 1 0.25 1 0.25 0 0.75
## 16 1 0.25 1 0.25 1 0.25 1 0.25
# Add `global_outcome`: the row-wise sum of every column whose name contains
# `var_key`. Matching is a fixed-string, case-insensitive substring match,
# the same rule dplyr::contains() applies; several keys are OR-ed together.
#
# data:    a data frame such as a Trials object's $out table.
# var_key: substring(s) identifying the columns to add up.
# Returns `data` with `global_outcome` added, or replaced if already present.
#
# An existing `global_outcome` column is excluded from the sum, so calling
# this twice does not double-count. Base rowSums() is used because across()
# without a function is deprecated in dplyr >= 1.1.0.
sum_across <- function(data, var_key = "outcome") {
  out_col <- "global_outcome"
  nms <- names(data)
  hits <- Reduce(
    `|`,
    lapply(tolower(var_key), grepl, x = tolower(nms), fixed = TRUE),
    FALSE
  )
  hits <- hits & nms != out_col
  data[[out_col]] <- unname(rowSums(data[nms[hits]]))
  data
}
# Add `global_outcome`: each row's values from the columns whose names match
# the regex `var_key`, pasted in column order with ", " (e.g. "0, 1, 1").
#
# data:    a data frame such as a Trials object's $out table.
# var_key: regular expression (stringr::str_detect semantics) selecting the
#          columns to combine.
# Returns `data` with `global_outcome` added, or replaced if already present.
#
# Differences from an apply()-based version:
# - The output column is never matched, so re-running is stable.
# - `data[cols]` keeps a data frame even for a single base data.frame column.
# - Pasting the columns directly avoids apply()'s matrix coercion, which can
#   format()-pad numbers when column types are mixed.
seq_across <- function(data, var_key = "outcome") {
  out_col <- "global_outcome"
  nms <- names(data)
  col_list <- nms[stringr::str_detect(nms, var_key) & nms != out_col]
  if (length(col_list) == 0) {
    stop("No columns match `var_key` = '", var_key, "'.", call. = FALSE)
  }
  data[[out_col]] <- do.call(
    paste,
    c(unname(as.list(data[col_list])), sep = ", ")
  )
  data
}
# Add `global_probs`: the row-wise product of every column whose name matches
# the regex `var_key`. For independent trials this product is the joint
# probability of the row.
#
# data:    a data frame such as a Trials object's $out table.
# var_key: regular expression (stringr::str_detect semantics) selecting the
#          probability columns.
# Returns `data` with `global_probs` added, or replaced if already present.
#
# "prob" is a substring of "global_probs", so the output column is excluded
# explicitly; otherwise a second call would multiply the old product in.
# `data[cols]` keeps a data frame even for a single base data.frame column.
prod_across <- function(data, var_key = "prob") {
  out_col <- "global_probs"
  nms <- names(data)
  col_list <- nms[stringr::str_detect(nms, var_key) & nms != out_col]
  if (length(col_list) == 0) {
    stop("No columns match `var_key` = '", var_key, "'.", call. = FALSE)
  }
  data[[out_col]] <- apply(as.matrix(data[col_list]), MARGIN = 1, FUN = prod)
  data
}
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Distribution of the number of successes in 1 + 1 + 5 = 7 fair coin flips:
# sum the outcome columns, multiply the probability columns, then total the
# joint probability for each success count.
bernoulli_trial(prob = .5) %>%
  add_trials(increment = 1) %>%
  add_trials(increment = 1) %>%
  add_trials(increment = 5) %>%
  magrittr::extract2("out") %>%
  sum_across() %>%
  prod_across() %>%
  group_by(global_outcome) %>%
  summarize(probs = sum(global_probs))
## # A tibble: 8 × 2
## global_outcome probs
## <dbl> <dbl>
## 1 0 0.00781
## 2 1 0.0547
## 3 2 0.164
## 4 3 0.273
## 5 4 0.273
## 6 5 0.164
## 7 6 0.0547
## 8 7 0.00781
# Probability of each ordered sequence of 3 fair coin flips (each is 1/8).
bernoulli_trial(prob = .5) %>%
  add_trials(increment = 3) %>%
  magrittr::extract2("out") %>%
  seq_across() %>%
  prod_across() %>%
  group_by(global_outcome) %>%
  summarize(probs = sum(global_probs))
## # A tibble: 8 × 2
## global_outcome probs
## <chr> <dbl>
## 1 0, 0, 0 0.125
## 2 0, 0, 1 0.125
## 3 0, 1, 0 0.125
## 4 0, 1, 1 0.125
## 5 1, 0, 0 0.125
## 6 1, 0, 1 0.125
## 7 1, 1, 0 0.125
## 8 1, 1, 1 0.125