Skip to content

Commit 807749f

Browse files
committed
fix bug in smote when only factor features are present
1 parent b42fb91 commit 807749f

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

NEWS

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mlr_2.2
22
- basic R learners now have new slots: name (a descriptive name of the algorithm),
33
short.name (abbreviation that can be used in plots and tables) and note
44
(notes regarding slight changes for the mlr integration of the learner and such).
5+
- fix a bug in "smote" when only factor features are present
56

67
- new learners:
78
-- classif.lqa

R/smote.R

+5-4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ smote = function(task, rate, nn = 5L) {
4040
y = data[, target]
4141
x = dropNamed(data, target)
4242
z = getMinMaxClass(y)
43+
if (z$min.size < nn)
44+
stopf("You cannot set nn = %i, when the minimal class has size %i!", nn, z$min.size)
4345
x.min = x[z$min.inds, , drop = FALSE]
4446
n.min = nrow(x.min)
4547
n.new = round(rate * n.min) - n.min
@@ -50,14 +52,13 @@ smote = function(task, rate, nn = 5L) {
5052
# convert xmin to matrix, so we can handle it better in C
5153
# factors are integer levels
5254
x.min.matrix = x.min
53-
if (any(is.num)) {
55+
if (any(!is.num)) {
5456
for (i in 1:ncol(x.min.matrix)) {
5557
if (!is.num[i])
56-
x.min.matrix[, i] = as.integer(x.min.matrix[, i])
58+
x.min.matrix[, i] = as.numeric(as.integer(x.min.matrix[, i]))
5759
}
5860
}
5961
x.min.matrix = as.matrix(x.min.matrix)
60-
6162
# dist matrix on smaller class, diag = NA so we don't find x as neighbor of x
6263
minclass.dist = as.matrix(daisy(x.min))
6364
diag(minclass.dist) = NA
@@ -69,7 +70,7 @@ smote = function(task, rate, nn = 5L) {
6970
res = as.data.frame(res)
7071

7172
# convert ints back to factors
72-
if (any(is.num)) {
73+
if (any(!is.num)) {
7374
for (i in 1:ncol(res)) {
7475
if (!is.num[i])
7576
res[, i] = as.factor(as.integer(res[, i]))

src/smote.c

+3-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@ SEXP c_smote(SEXP s_x, SEXP s_isnum, SEXP s_nn, SEXP s_res) {
1616
/* select a random minority obs and random neighbor */
1717
j_sel = runif(0, nrow_x);
1818
j_nn = runif(0, ncol_nn);
19-
/* matrix nn contains indexes of ncol_nn nearest neighbors for each minoriy obs (= rows);
20-
// as the indexes originate from R they are one-based, so the randomly chosen index has
21-
// to be subtracted by one in order to select the right row (j_nn) of zero-based matrix x
22-
*/
19+
/* matrix nn contains indexes of ncol_nn nearest neighbors for each minority obs (= rows)
20+
as the indexes originate from R they are one-based, so the randomly chosen index has
21+
to be subtracted by one in order to select the right row (j_nn) of zero-based matrix x */
2322
j_nn = nn[j_sel + j_nn * nrow_nn] - 1;
2423
lambda = unif_rand();
2524
for (R_len_t col = 0; col < ncol_x; col++) {

tests/testthat/test_imbal_smote.R

+20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ test_that("smote works", {
88
tab2 = table(df[, binaryclass.target])
99
expect_equal(tab2["M"], tab1["M"])
1010
expect_equal(tab2["R"], tab1["R"] * 2)
11+
12+
# check trivial error check
13+
d = data.frame(
14+
x1 = rep(c("a", "b"), 3, replace = TRUE),
15+
y = rep(c("a", "b"), 3, replace = TRUE)
16+
)
17+
task = makeClassifTask(data = d, target = "y")
18+
expect_error(smote(task, rate = 2), "minimal class has size 3")
1119
})
1220

1321
test_that("smote works with rate 1 (no new examples)", {
@@ -20,6 +28,18 @@ test_that("smote works with rate 1 (no new examples)", {
2028
expect_equal(tab2["R"], tab1["R"])
2129
})
2230

31+
test_that("smote works with only factor fetaures", {
32+
n = 10
33+
d = data.frame(
34+
x1 = sample(c("a", "b"), n, replace = TRUE),
35+
x2 = sample(c("a", "b"), n, replace = TRUE),
36+
y = sample(c("a", "b"), n, replace = TRUE)
37+
)
38+
task = makeClassifTask(data = d, target = "y")
39+
task2 = smote(task, rate = 1.2, nn = 2L)
40+
expect_equal(task2$task.desc$size, 11)
41+
})
42+
2343
test_that("smote wrapper", {
2444
rdesc = makeResampleDesc("CV", iters = 2)
2545
lrn1 = makeLearner("classif.rpart")

0 commit comments

Comments
 (0)