Skip to content

Commit e02d339

Browse files
authored
Merge pull request #527 from seb09/prevent-loo-in-vfold_cv
Prevent LOO through vfold_cv()
2 parents 5c8d38e + c24701c commit e02d339

File tree

7 files changed

+35
-5
lines changed

7 files changed

+35
-5
lines changed

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
* Improved documentation and formatting: function names are now more easily identifiable through either `()` at the end or being links to the function documentation (@brshallo, #521).
1818

19+
* `vfold_cv()` and `clustering_cv()` now error on implicit leave-one-out cross-validation (@seb09, #527).
20+
1921
## Bug fixes
2022

2123
* `vfold_cv()` now utilizes the `breaks` argument correctly for repeated cross-validation (@ZWael, #471).

R/loo.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#' @export
1414
loo_cv <- function(data, ...) {
1515
check_dots_empty()
16-
split_objs <- vfold_splits(data = data, v = nrow(data))
16+
split_objs <- vfold_splits(data = data, v = nrow(data), prevent_loo = FALSE)
1717
split_objs <-
1818
list(
1919
splits = map(split_objs$splits, change_class),

R/vfold.R

+10-4
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ vfold_cv <- function(data, v = 10, repeats = 1,
122122
}
123123

124124

125-
vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1) {
125+
vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1, prevent_loo = TRUE) {
126126

127127
n <- nrow(data)
128-
check_v(v, n, call = rlang::caller_env())
128+
check_v(v, n, prevent_loo = prevent_loo, call = rlang::caller_env())
129129

130130
if (is.null(strata)) {
131131
folds <- sample(rep(1:v, length.out = n))
@@ -311,7 +311,7 @@ group_vfold_splits <- function(data, group, v = NULL, balance, strata = NULL, po
311311
if (is.null(v)) {
312312
v <- max_v
313313
}
314-
check_v(v = v, max_v = max_v, rows = "groups", call = rlang::caller_env())
314+
check_v(v = v, max_v = max_v, rows = "groups", prevent_loo = FALSE, call = rlang::caller_env())
315315

316316
indices <- make_groups(data, group, v, balance, strata)
317317
indices <- lapply(indices, default_complement, n = nrow(data))
@@ -332,13 +332,19 @@ add_vfolds <- function(x, v) {
332332
x
333333
}
334334

335-
check_v <- function(v, max_v, rows = "rows", call = rlang::caller_env()) {
335+
check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) {
336336
if (!is.numeric(v) || length(v) != 1 || v < 2) {
337337
rlang::abort("`v` must be a single positive integer greater than 1", call = call)
338338
} else if (v > max_v) {
339339
rlang::abort(
340340
glue::glue("The number of {rows} is less than `v = {v}`"), call = call
341341
)
342+
} else if (prevent_loo && isTRUE(v == max_v)) {
343+
cli_abort(c(
344+
"Leave-one-out cross-validation is not supported by this function.",
345+
"x" = "You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.",
346+
"i" = "Use `loo_cv()` in this case."
347+
), call = call)
342348
}
343349
}
344350

tests/testthat/_snaps/clustering.md

+10
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@
3030

3131
`repeats` must be a single positive integer
3232

33+
---
34+
35+
Code
36+
clustering_cv(mtcars, mpg, v = nrow(mtcars))
37+
Condition
38+
Error in `clustering_cv()`:
39+
! Leave-one-out cross-validation is not supported by this function.
40+
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
41+
i Use `loo_cv()` in this case.
42+
3343
# printing
3444

3545
Code

tests/testthat/_snaps/vfold.md

+10
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@
3535

3636
`repeats` must be a single positive integer
3737

38+
---
39+
40+
Code
41+
vfold_cv(mtcars, v = nrow(mtcars))
42+
Condition
43+
Error in `vfold_cv()`:
44+
! Leave-one-out cross-validation is not supported by this function.
45+
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
46+
i Use `loo_cv()` in this case.
47+
3848
# printing
3949

4050
Code

tests/testthat/test-clustering.R

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ test_that("bad args", {
4545
expect_snapshot(error = TRUE, clustering_cv(Orange, v = 1, vars = "Tree"))
4646
expect_snapshot_error(clustering_cv(Orange, repeats = 0))
4747
expect_snapshot_error(clustering_cv(Orange, repeats = NULL))
48+
expect_snapshot(error = TRUE, clustering_cv(mtcars, mpg, v = nrow(mtcars)))
4849
})
4950

5051
test_that("printing", {

tests/testthat/test-vfold.R

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ test_that("bad args", {
8585
expect_snapshot_error(vfold_cv(iris, v = 150, repeats = 2))
8686
expect_snapshot_error(vfold_cv(Orange, repeats = 0))
8787
expect_snapshot_error(vfold_cv(Orange, repeats = NULL))
88+
expect_snapshot(error = TRUE, vfold_cv(mtcars, v = nrow(mtcars)))
8889
})
8990

9091
test_that("printing", {

0 commit comments

Comments
 (0)