Skip to content

Commit bf70b7b

Browse files
authored
Merge pull request #548 from tidymodels/checks-vfold
Update input checks for `vfold_cv.R`
2 parents b02f57d + a74d3a4 commit bf70b7b

File tree

5 files changed

+79
-41
lines changed

5 files changed

+79
-41
lines changed

R/clustering.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ clustering_cv <- function(data,
5757
distance_function = "dist",
5858
cluster_function = c("kmeans", "hclust"),
5959
...) {
60-
check_repeats(repeats)
60+
check_number_whole(repeats, min = 1)
6161

6262
if (!rlang::is_function(cluster_function)) {
6363
cluster_function <- rlang::arg_match(cluster_function)

R/vfold.R

+14-16
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ vfold_cv <- function(data, v = 10, repeats = 1,
7272
}
7373

7474
check_strata(strata, data)
75-
check_repeats(repeats)
75+
check_number_whole(repeats, min = 1)
7676

7777
if (repeats == 1) {
7878
split_objs <- vfold_splits(
@@ -213,7 +213,7 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1, pr
213213
#' @export
214214
group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ..., strata = NULL, pool = 0.1) {
215215
check_dots_empty()
216-
check_repeats(repeats)
216+
check_number_whole(repeats, min = 1)
217217
group <- validate_group({{ group }}, data)
218218
balance <- rlang::arg_match(balance)
219219

@@ -331,23 +331,24 @@ add_vfolds <- function(x, v) {
331331
}
332332

333333
check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) {
334-
if (!is.numeric(v) || length(v) != 1 || v < 2) {
335-
cli_abort("{.arg v} must be a single positive integer greater than 1.", call = call)
336-
} else if (v > max_v) {
334+
check_number_whole(v, min = 2, call = call)
335+
336+
if (v > max_v) {
337337
cli_abort(
338338
"The number of {rows} is less than {.arg v} = {.val {v}}.",
339339
call = call
340340
)
341-
} else if (prevent_loo && isTRUE(v == max_v)) {
341+
}
342+
if (prevent_loo && isTRUE(v == max_v)) {
342343
cli_abort(c(
343344
"Leave-one-out cross-validation is not supported by this function.",
344-
"x" = "You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.",
345-
"i" = "Use `loo_cv()` in this case."
345+
"x" = "You set {.arg v} to {.code nrow(data)}, which would result in a leave-one-out cross-validation.",
346+
"i" = "Use {.fn loo_cv} in this case."
346347
), call = call)
347348
}
348349
}
349350

350-
check_grouped_strata <- function(group, strata, pool, data) {
351+
check_grouped_strata <- function(group, strata, pool, data, call = caller_env()) {
351352

352353
strata <- tidyselect::vars_select(names(data), !!enquo(strata))
353354

@@ -363,14 +364,11 @@ check_grouped_strata <- function(group, strata, pool, data) {
363364

364365
if (nrow(vctrs::vec_unique(grouped_table)) !=
365366
nrow(vctrs::vec_unique(grouped_table["group"]))) {
366-
cli_abort("{.arg strata} must be constant across all members of each {.arg group}.")
367+
cli_abort(
368+
"{.field strata} must be constant across all members of each {.field group}.",
369+
call = call
370+
)
367371
}
368372

369373
strata
370374
}
371-
372-
check_repeats <- function(repeats, call = rlang::caller_env()) {
373-
if (!is.numeric(repeats) || length(repeats) != 1 || repeats < 1) {
374-
cli_abort("{.arg repeats} must be a single positive integer.", call = call)
375-
}
376-
}

tests/testthat/_snaps/clustering.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
clustering_cv(iris, Sepal.Length, v = -500)
1313
Condition
1414
Error in `clustering_cv()`:
15-
! `v` must be a single positive integer greater than 1.
15+
! `v` must be a whole number larger than or equal to 2, not the number -500.
1616

1717
---
1818

@@ -36,23 +36,23 @@
3636
clustering_cv(Orange, v = 1, vars = "Tree")
3737
Condition
3838
Error in `clustering_cv()`:
39-
! `v` must be a single positive integer greater than 1.
39+
! `v` must be a whole number larger than or equal to 2, not the number 1.
4040

4141
---
4242

4343
Code
4444
clustering_cv(Orange, repeats = 0)
4545
Condition
4646
Error in `clustering_cv()`:
47-
! `repeats` must be a single positive integer.
47+
! `repeats` must be a whole number larger than or equal to 1, not the number 0.
4848

4949
---
5050

5151
Code
5252
clustering_cv(Orange, repeats = NULL)
5353
Condition
5454
Error in `clustering_cv()`:
55-
! `repeats` must be a single positive integer.
55+
! `repeats` must be a whole number, not `NULL`.
5656

5757
---
5858

tests/testthat/_snaps/vfold.md

+24-16
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,29 @@
4141
! strata cannot be a <Surv> object.
4242
i Use the time or event variable directly.
4343

44-
# bad args
44+
# v arg is checked
4545

4646
Code
4747
vfold_cv(iris, v = -500)
4848
Condition
4949
Error in `vfold_cv()`:
50-
! `v` must be a single positive integer greater than 1.
50+
! `v` must be a whole number larger than or equal to 2, not the number -500.
5151

5252
---
5353

5454
Code
5555
vfold_cv(iris, v = 1)
5656
Condition
5757
Error in `vfold_cv()`:
58-
! `v` must be a single positive integer greater than 1.
58+
! `v` must be a whole number larger than or equal to 2, not the number 1.
5959

6060
---
6161

6262
Code
6363
vfold_cv(iris, v = NULL)
6464
Condition
6565
Error in `vfold_cv()`:
66-
! `v` must be a single positive integer greater than 1.
66+
! `v` must be a whole number, not `NULL`.
6767

6868
---
6969

@@ -76,36 +76,36 @@
7676
---
7777

7878
Code
79-
vfold_cv(iris, v = 150, repeats = 2)
79+
vfold_cv(mtcars, v = nrow(mtcars))
8080
Condition
8181
Error in `vfold_cv()`:
82-
! Repeated resampling when `v` is 150 would create identical resamples.
82+
! Leave-one-out cross-validation is not supported by this function.
83+
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
84+
i Use `loo_cv()` in this case.
8385

84-
---
86+
# repeats arg is checked
8587

8688
Code
87-
vfold_cv(Orange, repeats = 0)
89+
vfold_cv(iris, v = 150, repeats = 2)
8890
Condition
8991
Error in `vfold_cv()`:
90-
! `repeats` must be a single positive integer.
92+
! Repeated resampling when `v` is 150 would create identical resamples.
9193

9294
---
9395

9496
Code
95-
vfold_cv(Orange, repeats = NULL)
97+
vfold_cv(Orange, repeats = 0)
9698
Condition
9799
Error in `vfold_cv()`:
98-
! `repeats` must be a single positive integer.
100+
! `repeats` must be a whole number larger than or equal to 1, not the number 0.
99101

100102
---
101103

102104
Code
103-
vfold_cv(mtcars, v = nrow(mtcars))
105+
vfold_cv(Orange, repeats = NULL)
104106
Condition
105107
Error in `vfold_cv()`:
106-
! Leave-one-out cross-validation is not supported by this function.
107-
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
108-
i Use `loo_cv()` in this case.
108+
! `repeats` must be a whole number, not `NULL`.
109109

110110
# printing
111111

@@ -191,7 +191,7 @@
191191
group_vfold_cv(Orange, v = 1, group = "Tree")
192192
Condition
193193
Error in `group_vfold_cv()`:
194-
! `v` must be a single positive integer greater than 1.
194+
! `v` must be a whole number larger than or equal to 2, not the number 1.
195195

196196
# grouping -- other balance methods
197197

@@ -286,6 +286,14 @@
286286
10 <split [96051/3949]> Resample10
287287
# i 20 more rows
288288

289+
# grouping fails for strata not constant across group members
290+
291+
Code
292+
group_vfold_cv(sample_data, group, v = 5, strata = outcome)
293+
Condition
294+
Error in `group_vfold_cv()`:
295+
! strata must be constant across all members of each group.
296+
289297
# grouping -- printing
290298

291299
Code

tests/testthat/test-vfold.R

+36-4
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ test_that("strata arg is checked", {
104104
})
105105
})
106106

107-
test_that("bad args", {
107+
test_that("v arg is checked", {
108108
expect_snapshot(error = TRUE, {
109109
vfold_cv(iris, v = -500)
110110
})
@@ -117,6 +117,12 @@ test_that("bad args", {
117117
expect_snapshot(error = TRUE, {
118118
vfold_cv(iris, v = 500)
119119
})
120+
expect_snapshot(error = TRUE, {
121+
vfold_cv(mtcars, v = nrow(mtcars))
122+
})
123+
})
124+
125+
test_that("repeats arg is checked", {
120126
expect_snapshot(error = TRUE, {
121127
vfold_cv(iris, v = 150, repeats = 2)
122128
})
@@ -126,9 +132,6 @@ test_that("bad args", {
126132
expect_snapshot(error = TRUE, {
127133
vfold_cv(Orange, repeats = NULL)
128134
})
129-
expect_snapshot(error = TRUE, {
130-
vfold_cv(mtcars, v = nrow(mtcars))
131-
})
132135
})
133136

134137
test_that("printing", {
@@ -403,6 +406,35 @@ test_that("grouping -- strata", {
403406
)
404407
})
405408

409+
test_that("grouping fails for strata not constant across group members", {
410+
set.seed(11)
411+
412+
n_common_class <- 70
413+
n_rare_class <- 30
414+
415+
group_table <- tibble(
416+
group = 1:100,
417+
outcome = sample(c(rep(0, n_common_class), rep(1, n_rare_class)))
418+
)
419+
observation_table <- tibble(
420+
group = sample(1:100, 1e5, replace = TRUE),
421+
observation = 1:1e5
422+
)
423+
sample_data <- dplyr::full_join(
424+
group_table,
425+
observation_table,
426+
by = "group",
427+
multiple = "all"
428+
)
429+
430+
# violate requirement
431+
sample_data$outcome[1] <- ifelse(sample_data$outcome[1], 0, 1)
432+
433+
expect_snapshot(error = TRUE, {
434+
group_vfold_cv(sample_data, group, v = 5, strata = outcome)
435+
})
436+
})
437+
406438
test_that("grouping -- repeated", {
407439
set.seed(11)
408440
rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4)

0 commit comments

Comments
 (0)