Skip to content

Commit ce33c22

Browse files
authored
Merge pull request #113 from mlr-org/add_param6
fix distr6 learners
2 parents b9aa81b + c5d16ac commit ce33c22

6 files changed

+27
-39
lines changed

.lintr

+1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ linters: with_defaults(
55
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
66
cyclocomp_linter = NULL, # do not check function complexity
77
commented_code_linter = NULL, # allow code in comments
8+
todo_comment_linter = NULL, # allow todo in comments
89
line_length_linter = line_length_linter(100)
910
)

DESCRIPTION

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3extralearners
22
Title: Extra Learners For mlr3
3-
Version: 0.5.5
3+
Version: 0.5.6
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",
@@ -80,6 +80,7 @@ Suggests:
8080
nnet,
8181
np,
8282
obliqueRSF,
83+
param6,
8384
partykit,
8485
penalized,
8586
pendensity,
@@ -96,7 +97,7 @@ Suggests:
9697
sm,
9798
stats,
9899
survival,
99-
survivalmodels (>= 0.1.4),
100+
survivalmodels (>= 0.1.9),
100101
survivalsvm,
101102
tensorflow (>= 2.0.0),
102103
testthat,

NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# mlr3extralearners 0.5.6
2+
3+
* Fix learners requiring distr6. distr6 1.6.0 now forced and param6 added to suggests
4+
15
# mlr3extralearners 0.5.5
26

37
* Bugfix `regr.gausspr`

R/learner_flexsurv_surv_flexible.R

+18-36
Original file line numberDiff line numberDiff line change
@@ -155,51 +155,35 @@ predict_flexsurvreg <- function(object, task, ...) {
155155
# parameters above.
156156
pdf = function(x) {} # nolint
157157
body(pdf) = substitute({
158-
fn = func
159-
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
160-
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
161-
do.call(fn, c(list(x = x), args))
158+
do.call(func, c(list(x = x), self$parameters()$values))
162159
}, list(func = object$dfns$d))
163160

164161
cdf = function(x) {} # nolint
165162
body(cdf) = substitute({
166-
fn = func
167-
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
168-
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
169-
do.call(fn, c(list(q = x), args))
163+
do.call(func, c(list(q = x), self$parameters()$values))
170164
}, list(func = object$dfns$p))
171165

172166
quantile = function(p) {} # nolint
173167
body(quantile) = substitute({
174-
fn = func
175-
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
176-
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
177-
do.call(fn, c(list(p = p), args))
168+
do.call(func, c(list(p = p), self$parameters()$values))
178169
}, list(func = object$dfns$q))
179170

180171
rand = function(n) {} # nolint
181172
body(rand) = substitute({
182-
fn = func
183-
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
184-
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
185-
do.call(fn, c(list(n = n), args))
173+
do.call(func, c(list(n = n), self$parameters()$values))
186174
}, list(func = object$dfns$r))
187175

188176
# The parameter set combines the auxiliary parameters with the fitted gamma coefficients.
189-
# Whilst the
190-
# user can set these after fitting, this is generally ill-advised.
191-
parameters = distr6::ParameterSet$new(
192-
id = c(names(args), object$dlist$pars),
193-
value = c(list(
194-
numeric(length(object$knots)),
195-
"hazard", "log"), rep(list(0), length(object$dlist$pars))),
196-
settable = rep(TRUE, length(args) + length(object$dlist$pars)),
197-
support = c(
198-
list(set6::Reals$new()^length(object$knots)),
199-
set6::Set$new("hazard", "odds", "normal"),
200-
set6::Set$new("log", "identity"),
201-
rep(list(set6::Reals$new()), length(object$dlist$pars)))
202-
)
177+
# Whilst the user can set these after fitting, this is generally ill-advised.
178+
parameters = param6::ParameterSet$new(c(list(
179+
param6::prm(
180+
"knots", set6::Reals$new()^length(object$knots),
181+
numeric(length(object$knots))
182+
),
183+
param6::prm("scale", set6::Set$new("hazard", "odds", "normal"), "hazard"),
184+
param6::prm("timescale", set6::Set$new("log", "identity"), "log")),
185+
lapply(object$dlist$pars, function(x) param6::prm(x, "reals", 0))
186+
))
203187

204188
pars = data.table::data.table(t(pars))
205189
pargs = data.table::data.table(matrix(args, ncol = ncol(pars), nrow = length(args)))
@@ -217,18 +201,16 @@ predict_flexsurvreg <- function(object, task, ...) {
217201
pdf = pdf, cdf = cdf, quantile = quantile, rand = rand
218202
)
219203

204+
## FIXME - This is bad and needs speeding up
220205
distlist = lapply(pars, function(x) {
221-
x = as.list(x)
222-
names(x) = c(object$dlist$pars, names(args))
223206
yparams = parameters$clone(deep = TRUE)
224-
ind = match(yparams$.__enclos_env__$private$.parameters$id, names(x))
225-
yparams$.__enclos_env__$private$.parameters$value = x[ind]
207+
yparams$values = setNames(as.list(x), c(object$dlist$pars, names(args)))
226208

227209
do.call(distr6::Distribution$new, c(list(parameters = yparams), shared_params))
228210
})
229211

230-
distr = distr6::VectorDistribution$new(distlist,
231-
decorators = c("CoreStatistics", "ExoticStatistics"))
212+
distr = distr6::VectorDistribution$new(
213+
distlist, decorators = c("CoreStatistics", "ExoticStatistics"))
232214

233215
return(list(distr = distr, lp = lp))
234216
}

R/learner_survival_surv_parametric.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ LearnerSurvParametric = R6Class("LearnerSurvParametric", inherit = mlr3proba::Le
216216
},
217217
cdf = function() {
218218
},
219-
parameters = distr6::ParameterSet$new()
219+
parameters = param6::pset()
220220
))
221221

222222
params = rep(params, length(lp))

R/sysdata.rda

11 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)