diff --git a/R/RcppExports.R b/R/RcppExports.R index e22b1d2..017fb2e 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -53,6 +53,10 @@ multiples_of <- function(x, divisor, subset_out = FALSE) { .Call(`_MADMMplasso_multiples_of`, x, divisor, subset_out) } +lm_arma <- function(R, Z) { + .Call(`_MADMMplasso_lm_arma`, R, Z) +} + reg <- function(r, Z) { .Call(`_MADMMplasso_reg`, r, Z) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index dd8e899..5900628 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -101,6 +101,18 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// lm_arma +arma::vec lm_arma(const arma::vec& R, const arma::mat& Z); +RcppExport SEXP _MADMMplasso_lm_arma(SEXP RSEXP, SEXP ZSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::vec& >::type R(RSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type Z(ZSEXP); + rcpp_result_gen = Rcpp::wrap(lm_arma(R, Z)); + return rcpp_result_gen; +END_RCPP +} // reg Rcpp::List reg(const arma::mat r, const arma::mat Z); RcppExport SEXP _MADMMplasso_reg(SEXP rSEXP, SEXP ZSEXP) { @@ -143,6 +155,7 @@ static const R_CallMethodDef CallEntries[] = { {"_MADMMplasso_model_p", (DL_FUNC) &_MADMMplasso_model_p, 6}, {"_MADMMplasso_modulo", (DL_FUNC) &_MADMMplasso_modulo, 2}, {"_MADMMplasso_multiples_of", (DL_FUNC) &_MADMMplasso_multiples_of, 3}, + {"_MADMMplasso_lm_arma", (DL_FUNC) &_MADMMplasso_lm_arma, 2}, {"_MADMMplasso_reg", (DL_FUNC) &_MADMMplasso_reg, 2}, {"_MADMMplasso_scale_cpp", (DL_FUNC) &_MADMMplasso_scale_cpp, 2}, {"_MADMMplasso_sqrt_sum_squared_rows", (DL_FUNC) &_MADMMplasso_sqrt_sum_squared_rows, 1}, diff --git a/src/reg.cpp b/src/reg.cpp index a3cbe47..27efc6c 100644 --- a/src/reg.cpp +++ b/src/reg.cpp @@ -1,5 +1,17 @@ #include // [[Rcpp::depends(RcppArmadillo)]] + +// [[Rcpp::export]] +arma::vec lm_arma(const arma::vec &R, const arma::mat &Z) { + // Add a column of ones to Z + arma::mat Z_intercept = arma::join_rows(arma::ones(Z.n_rows), Z); + + // Solve the system of linear equations + arma::vec coefficients = arma::solve(Z_intercept, R); + + return coefficients; +} + // [[Rcpp::export]] Rcpp::List reg( const arma::mat r, @@ -10,9 +22,9 @@ Rcpp::List reg( arma::mat theta01(Z.n_cols, r.n_cols, arma::fill::zeros); for (arma::uword e = 0; e < r.n_cols; e++) { - arma::vec new1 = arma::solve(Z, r.col(e)); + arma::vec new1 = lm_arma(r.col(e), Z); beta01(e) = new1(0); - theta01.col(e) = new1.tail(new1.n_elem); + theta01.col(e) = new1.tail(new1.n_elem - 1); } Rcpp::List out = Rcpp::List::create(