Skip to content

Commit 4b853fe

Browse files
authored
Merge pull request #310 from tidymodels/a_little_unbalanced
Remove balance
2 parents 7752ef1 + 18d0113 commit 4b853fe

File tree

6 files changed

+8
-98
lines changed

6 files changed

+8
-98
lines changed

NEWS.md

-2
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,6 @@
1616

1717
* Added better printing methods for initial split objects.
1818

19-
* Added a new `balance` option to `group_vfold_cv()` to balance folds either by the number of groups or the number of observations (@mikemahoney218, #300).
20-
2119
# rsample 0.1.1
2220

2321
* Updated documentation on stratified sampling (#245).

R/groups.R

+4-9
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,9 @@
2222
#'
2323
#' set.seed(123)
2424
#' group_vfold_cv(Sacramento, group = city, v = 5)
25-
#' group_vfold_cv(Sacramento, group = city, v = 5, balance = "observations")
2625
#'
2726
#' @export
28-
group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "observations"), ...) {
27+
group_vfold_cv <- function(data, group = NULL, v = NULL, ...) {
2928
if (!missing(group)) {
3029
group <- tidyselect::vars_select(names(data), !!enquo(group))
3130
if (length(group) == 0) {
@@ -42,9 +41,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "
4241
rlang::abort("`group` should be a column in `data`.")
4342
}
4443

45-
balance <- rlang::arg_match(balance)
46-
47-
split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance)
44+
split_objs <- group_vfold_splits(data = data, group = group, v = v)
4845

4946
## We remove the holdout indices since it will save space and we can
5047
## derive them later when they are needed.
@@ -68,9 +65,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "
6865
)
6966
}
7067

71-
group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "observations")) {
72-
73-
balance <- rlang::arg_match(balance)
68+
group_vfold_splits <- function(data, group, v = NULL) {
7469

7570
group <- getElement(data, group)
7671
max_v <- length(unique(group))
@@ -81,7 +76,7 @@ group_vfold_splits <- function(data, group, v = NULL, balance = c("groups", "obs
8176
check_v(v = v, max_v = max_v, rows = "rows", call = rlang::caller_env())
8277
}
8378

84-
indices <- make_groups(data, group, v, balance)
79+
indices <- make_groups(data, group, v)
8580
indices <- lapply(indices, default_complement, n = nrow(data))
8681
split_objs <-
8782
purrr::map(indices,

R/make_groups.R

+2-35
Original file line number | Diff line number | Diff line change
@@ -8,22 +8,13 @@
88
#' @param group A variable in `data` (single character or name) used for
99
#' grouping observations with the same value to either the analysis or
1010
#' assessment set within a fold.
11-
#' @param balance If `v` is less than the number of unique groups, how should
12-
#' groups be combined into folds? If `"groups"`, the default, then groups are
13-
#' combined randomly to balance the number of groups in each fold.
14-
#' If `"observations"`, then groups are combined to balance the number of
15-
#' observations in each fold.
1611
#'
1712
#' @keywords internal
18-
make_groups <- function(data, group, v, balance) {
13+
make_groups <- function(data, group, v) {
1914
data_ind <- data.frame(..index = 1:nrow(data), ..group = group)
2015
data_ind$..group <- as.character(data_ind$..group)
2116

22-
res <- switch(
23-
balance,
24-
"groups" = balance_groups(data_ind, v),
25-
"observations" = balance_observations(data_ind, v)
26-
)
17+
res <- balance_groups(data_ind, v)
2718
data_ind <- res$data_ind
2819
keys <- res$keys
2920

@@ -47,27 +38,3 @@ balance_groups <- function(data_ind, v) {
4738
keys = keys
4839
)
4940
}
50-
51-
balance_observations <- function(data_ind, v) {
52-
while (vec_unique_count(data_ind$..group) > v) {
53-
freq_table <- vec_count(data_ind$..group)
54-
# Recategorize the largest group to be collapsed
55-
# as the smallest group to be kept
56-
group_to_keep <- vec_slice(freq_table, v)
57-
group_to_collapse <- vec_slice(freq_table, v + 1)
58-
collapse_lgl <- vec_in(data_ind$..group, group_to_collapse$key)
59-
vec_slice(data_ind$..group, collapse_lgl) <- group_to_keep$key
60-
}
61-
unique_groups <- unique(data_ind$..group)
62-
63-
keys <- data.frame(
64-
..group = unique_groups,
65-
..folds = sample(rep(seq_len(v), length.out = length(unique_groups)))
66-
)
67-
68-
list(
69-
data_ind = data_ind,
70-
keys = keys
71-
)
72-
73-
}

man/group_vfold_cv.Rd

+1-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/make_groups.Rd

+1-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-groups.R

-31
Original file line number | Diff line number | Diff line change
@@ -84,37 +84,6 @@ test_that("tibble input", {
8484
expect_true(all(table(sp_out) == 1))
8585
})
8686

87-
test_that("other balance methods", {
88-
data(ames, package = "modeldata")
89-
set.seed(11)
90-
rs1 <- group_vfold_cv(ames, "Neighborhood", balance = "observations", v = 2)
91-
sizes1 <- dim_rset(rs1)
92-
93-
expect_true(all(sizes1$analysis == 1465))
94-
expect_true(all(sizes1$assessment == 1465))
95-
same_data <-
96-
purrr::map_lgl(rs1$splits, function(x) {
97-
all.equal(x$data, ames)
98-
})
99-
expect_true(all(same_data))
100-
101-
good_holdout <- purrr::map_lgl(
102-
rs1$splits,
103-
function(x) {
104-
length(intersect(x$in_ind, x$out_id)) == 0
105-
}
106-
)
107-
expect_true(all(good_holdout))
108-
109-
expect_true(
110-
!any(
111-
unique(as.character(assessment(rs1$splits[[1]])$Neighborhood)) %in%
112-
unique(as.character(analysis(rs1$splits[[1]])$Neighborhood))
113-
)
114-
)
115-
116-
})
117-
11887
test_that("printing", {
11988
expect_snapshot(group_vfold_cv(warpbreaks, "tension"))
12089
})

0 commit comments

Comments
 (0)