Skip to content

Commit 617b619

Browse files
authored
Merge pull request #465 from tidymodels/expand-group-intervals
Expand grouping variables for bootstrap intervals
2 parents 8ccee92 + d839b1f commit 617b619

File tree

4 files changed

+180
-38
lines changed

4 files changed

+180
-38
lines changed

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
* `vfold_cv()` and `clustering_cv()` now error on implicit leave-one-out cross-validation (@seb09, #527).
2020

21+
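* Bootstrap intervals via `int_pctl()`, `int_t()`, and `int_bca()` now support additional grouping variables: columns in the statistics whose names begin with a period are combined with `term` to define the groups for which intervals are computed (#465).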
22+
2123
## Bug fixes
2224

2325
* `vfold_cv()` now utilizes the `breaks` argument correctly for repeated cross-validation (@ZWael, #471).

R/bootci.R

+81-34
Original file line numberDiff line numberDiff line change
@@ -65,41 +65,46 @@ check_tidy <- function(x, std_col = FALSE) {
6565
if (std_col) {
6666
std_candidates <- colnames(x) %in% std_exp
6767
std_candidates <- colnames(x)[std_candidates]
68+
re_name <- list(std_err = std_candidates)
6869
if (has_id) {
6970
x <-
70-
dplyr::select(x, term, estimate, id, tidyselect::one_of(std_candidates)) %>%
71-
mutate(id = (id == "Apparent")) %>%
72-
setNames(c("term", "estimate", "orig", "std_err"))
71+
dplyr::select(x, term, estimate, id, tidyselect::one_of(std_candidates),
72+
dplyr::starts_with(".")) %>%
73+
mutate(orig = (id == "Apparent")) %>%
74+
dplyr::rename(!!!re_name)
7375
} else {
7476
x <-
75-
dplyr::select(x, term, estimate, tidyselect::one_of(std_candidates)) %>%
76-
setNames(c("term", "estimate", "std_err"))
77+
dplyr::select(x, term, estimate, tidyselect::one_of(std_candidates),
78+
dplyr::starts_with(".")) %>%
79+
dplyr::rename(!!!re_name)
7780
}
7881
} else {
7982
if (has_id) {
8083
x <-
81-
dplyr::select(x, term, estimate, id) %>%
84+
dplyr::select(x, term, estimate, id, dplyr::starts_with(".")) %>%
8285
mutate(orig = (id == "Apparent")) %>%
8386
dplyr::select(-id)
8487
} else {
85-
x <- dplyr::select(x, term, estimate)
88+
x <- dplyr::select(x, term, estimate, dplyr::starts_with("."))
8689
}
8790
}
8891

8992
x
9093
}
9194

9295

93-
get_p0 <- function(x, alpha = 0.05) {
96+
get_p0 <- function(x, alpha = 0.05, groups) {
97+
group_sym <- rlang::syms(groups)
98+
9499
orig <- x %>%
95-
group_by(term) %>%
100+
group_by(!!!group_sym) %>%
96101
dplyr::filter(orig) %>%
97-
dplyr::select(term, theta_0 = estimate) %>%
102+
dplyr::select(!!!group_sym, theta_0 = estimate) %>%
98103
ungroup()
99104
x %>%
100105
dplyr::filter(!orig) %>%
101-
inner_join(orig, by = "term") %>%
102-
group_by(term) %>%
106+
inner_join(orig, by = groups) %>%
107+
group_by(!!!group_sym) %>%
103108
summarize(p0 = mean(estimate <= theta_0, na.rm = TRUE)) %>%
104109
mutate(
105110
Z0 = stats::qnorm(p0),
@@ -172,9 +177,10 @@ pctl_single <- function(stats, alpha = 0.05) {
172177
#' @param statistics An unquoted column name or `dplyr` selector that identifies
173178
#' a single column in the data set containing the individual bootstrap
174179
#' estimates. This must be a list column of tidy tibbles (with columns
175-
#' `term` and `estimate`). For t-intervals, a
176-
#' standard tidy column (usually called `std.err`) is required.
177-
#' See the examples below.
180+
#' `term` and `estimate`). Optionally, users can include columns whose names
181+
#' begin with a period; intervals are then created for each combination
182+
#' of these variables and the `term` column. For t-intervals, a standard tidy
183+
#' column (usually called `std.err`) is required. See the examples below.
178184
#' @param alpha Level of significance.
179185
#' @param .fn A function to calculate statistic of interest. The
180186
#' function should take an `rsplit` as the first argument and the `...` are
@@ -200,12 +206,15 @@ pctl_single <- function(stats, alpha = 0.05) {
200206
#' Application_. Cambridge: Cambridge University Press.
201207
#' doi:10.1017/CBO9780511802843
202208
#'
203-
#' @examplesIf rlang::is_installed("broom")
209+
#' @examplesIf rlang::is_installed("broom") & rlang::is_installed("modeldata")
204210
#' \donttest{
205211
#' library(broom)
206212
#' library(dplyr)
207213
#' library(purrr)
208214
#' library(tibble)
215+
#' library(tidyr)
216+
#'
217+
#' # ------------------------------------------------------------------------------
209218
#'
210219
#' lm_est <- function(split, ...) {
211220
#' lm(mpg ~ disp + hp, data = analysis(split)) %>%
@@ -221,6 +230,8 @@ pctl_single <- function(stats, alpha = 0.05) {
221230
#' int_t(car_rs, results)
222231
#' int_bca(car_rs, results, .fn = lm_est)
223232
#'
233+
#' # ------------------------------------------------------------------------------
234+
#'
224235
#' # putting results into a tidy format
225236
#' rank_corr <- function(split) {
226237
#' dat <- analysis(split)
@@ -237,6 +248,31 @@ pctl_single <- function(stats, alpha = 0.05) {
237248
#' bootstraps(Sacramento, 1000, apparent = TRUE) %>%
238249
#' mutate(correlations = map(splits, rank_corr)) %>%
239250
#' int_pctl(correlations)
251+
#'
252+
#' # ------------------------------------------------------------------------------
253+
#' # An example of computing the interval for each value of a custom grouping
254+
#' # factor (type of house in this example)
255+
#'
256+
#' # Get regression estimates for each house type
257+
#' lm_est <- function(split, ...) {
258+
#' analysis(split) %>%
259+
#' tidyr::nest(.by = c(type)) %>%
260+
#' # Compute regression estimates for each house type
261+
#' mutate(
262+
#' betas = purrr::map(data, ~ lm(log10(price) ~ sqft, data = .x) %>% tidy())
263+
#' ) %>%
264+
#' # Convert the column name to begin with a period
265+
#' rename(.type = type) %>%
266+
#' select(.type, betas) %>%
267+
#' unnest(cols = betas)
268+
#' }
269+
#'
270+
#' set.seed(52156)
271+
#' house_rs <-
272+
#' bootstraps(Sacramento, 1000, apparent = TRUE) %>%
273+
#' mutate(results = map(splits, lm_est))
274+
#'
275+
#' int_pctl(house_rs, results)
240276
#' }
241277
#' @export
242278
int_pctl <- function(.data, ...) {
@@ -263,8 +299,11 @@ int_pctl.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {
263299

264300
check_num_resamples(stats, B = 1000)
265301

302+
stat_groups <- c("term", grep("^\\.", names(stats), value = TRUE))
303+
stat_groups <- rlang::syms(stat_groups)
304+
266305
vals <- stats %>%
267-
dplyr::group_by(term) %>%
306+
dplyr::group_by(!!!stat_groups) %>%
268307
dplyr::do(pctl_single(.$estimate, alpha = alpha)) %>%
269308
dplyr::ungroup()
270309
vals
@@ -343,9 +382,10 @@ int_t.bootstraps <- function(.data, statistics, alpha = 0.05, ...) {
343382

344383
check_num_resamples(stats, B = 500)
345384

346-
vals <-
347-
stats %>%
348-
dplyr::group_by(term) %>%
385+
stat_groups <- c("term", grep("^\\.", names(stats), value = TRUE))
386+
stat_groups <- rlang::syms(stat_groups)
387+
vals <- stats %>%
388+
dplyr::group_by(!!!stat_groups) %>%
349389
dplyr::do(t_single(.$estimate, .$std_err, .$orig, alpha = alpha)) %>%
350390
dplyr::ungroup()
351391
vals
@@ -361,8 +401,11 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
361401
cli_abort("All statistics have missing values.")
362402
}
363403

404+
stat_groups_chr <- c("term", grep("^\\.", names(stats), value = TRUE))
405+
stat_groups_sym <- rlang::syms(stat_groups_chr)
406+
364407
### Estimating Z0 bias-correction
365-
bias_corr_stats <- get_p0(stats, alpha = alpha)
408+
bias_corr_stats <- get_p0(stats, alpha = alpha, groups = stat_groups_chr)
366409

367410
# need the original data frame here
368411
loo_rs <- loo_cv(orig_data)
@@ -380,16 +423,16 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
380423

381424
loo_estimate <-
382425
loo_res %>%
383-
dplyr::group_by(term) %>%
426+
dplyr::group_by(!!!stat_groups_sym) %>%
384427
dplyr::summarize(loo = mean(estimate, na.rm = TRUE)) %>%
385-
dplyr::inner_join(loo_res, by = "term", multiple = "all") %>%
386-
dplyr::group_by(term) %>%
428+
dplyr::inner_join(loo_res, by = stat_groups_chr, multiple = "all") %>%
429+
dplyr::group_by(!!!stat_groups_sym) %>%
387430
dplyr::summarize(
388431
cubed = sum((loo - estimate)^3),
389432
squared = sum((loo - estimate)^2)
390433
) %>%
391434
dplyr::ungroup() %>%
392-
dplyr::inner_join(bias_corr_stats, by = "term") %>%
435+
dplyr::inner_join(bias_corr_stats, by = stat_groups_chr) %>%
393436
dplyr::mutate(
394437
a = cubed / (6 * (squared^(3 / 2))),
395438
Zu = (Z0 + Za) / (1 - a * (Z0 + Za)) + Z0,
@@ -400,21 +443,25 @@ bca_calc <- function(stats, orig_data, alpha = 0.05, .fn, ...) {
400443

401444
terms <- loo_estimate$term
402445
stats <- stats %>% dplyr::filter(!orig)
403-
for (i in seq_along(terms)) {
404-
tmp <- new_stats(stats$estimate[stats$term == terms[i]],
405-
lo = loo_estimate$lo[i],
406-
hi = loo_estimate$hi[i]
407-
)
408-
tmp$term <- terms[i]
446+
447+
keys <- stats %>% dplyr::distinct(!!!stat_groups_sym)
448+
for (i in seq_len(nrow(keys))) {
449+
tmp_stats <- dplyr::inner_join(stats, keys[i,], by = stat_groups_chr)
450+
tmp_loo <- dplyr::inner_join(loo_estimate, keys[i,], by = stat_groups_chr)
451+
452+
tmp <- new_stats(tmp_stats$estimate,
453+
lo = tmp_loo$lo,
454+
hi = tmp_loo$hi)
455+
tmp <- dplyr::bind_cols(tmp, keys[i,])
409456
if (i == 1) {
410457
ci_bca <- tmp
411458
} else {
412-
ci_bca <- bind_rows(ci_bca, tmp)
459+
ci_bca <- dplyr::bind_rows(ci_bca, tmp)
413460
}
414461
}
415462
ci_bca <-
416463
ci_bca %>%
417-
dplyr::select(term, .lower, .estimate, .upper) %>%
464+
dplyr::select(!!!stat_groups_sym, .lower, .estimate, .upper) %>%
418465
dplyr::mutate(
419466
.alpha = alpha,
420467
.method = "BCa"
@@ -441,7 +488,7 @@ int_bca.bootstraps <- function(.data, statistics, alpha = 0.05, .fn, ...) {
441488
if (length(column_name) != 1) {
442489
cli_abort(stat_fmt_err)
443490
}
444-
stats <- .data %>% dplyr::select(!!column_name, id)
491+
stats <- .data %>% dplyr::select(!!column_name, id, dplyr::starts_with("."))
445492
stats <- check_tidy(stats)
446493

447494
check_num_resamples(stats, B = 1000)

man/int_pctl.Rd

+35-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-bootci.R

+62
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,65 @@ test_that("regression intervals", {
285285
"must be a single numeric value"
286286
)
287287
})
288+
289+
test_that("compute intervals with additional grouping terms", {
290+
skip_if_not_installed("broom")
291+
292+
lm_coefs <- function(dat) {
293+
mod <- lm(mpg ~ I(1/disp), data = dat)
294+
tidy(mod)
295+
}
296+
297+
coef_by_engine_shape <- function(split, ...) {
298+
split %>%
299+
analysis() %>%
300+
dplyr::rename(.vs = vs) %>%
301+
tidyr::nest(.by = .vs) %>%
302+
dplyr::mutate(coefs = map(data, lm_coefs)) %>%
303+
dplyr::select(-data) %>%
304+
tidyr::unnest(coefs)
305+
}
306+
307+
set.seed(270)
308+
boot_rs <-
309+
bootstraps(mtcars, 1000, apparent = TRUE) %>%
310+
dplyr::mutate(results = purrr::map(splits, coef_by_engine_shape))
311+
312+
pctl_res <- int_pctl(boot_rs, results)
313+
t_res <- int_t(boot_rs, results)
314+
bca_res <- int_bca(boot_rs, results, .fn = coef_by_engine_shape)
315+
316+
exp_ptype <-
317+
tibble::tibble(
318+
term = character(0),
319+
.vs = numeric(0),
320+
.lower = numeric(0),
321+
.estimate = numeric(0),
322+
.upper = numeric(0),
323+
.alpha = numeric(0),
324+
.method = character(0)
325+
)
326+
327+
expect_equal(pctl_res[0, ], exp_ptype)
328+
expect_equal(t_res[0, ], exp_ptype)
329+
expect_equal(bca_res[0, ], exp_ptype)
330+
331+
exp_combos <-
332+
tibble::tribble(
333+
~term, ~.vs,
334+
"(Intercept)", 0,
335+
"(Intercept)", 1,
336+
"I(1/disp)", 0,
337+
"I(1/disp)", 1
338+
)
339+
340+
group_patterns <- function(x) {
341+
dplyr::distinct(x, term, .vs) %>%
342+
dplyr::arrange(term, .vs)
343+
}
344+
345+
expect_equal(group_patterns(pctl_res), exp_combos)
346+
expect_equal(group_patterns(t_res), exp_combos)
347+
expect_equal(group_patterns(bca_res), exp_combos)
348+
})
349+

0 commit comments

Comments
 (0)