```r
## Simulation data.
set.seed(123)
N <- 100
K <- 4
p <- 50
X <- matrix(rnorm(n = N * p, mean = 0, sd = 1), nrow = N, ncol = p)
Z <- matrix(rbinom(n = N * K, size = 1, prob = 0.5), nrow = N, ncol = K)
## Response model.
beta <- rep(x = 0, times = p)
beta[1:4] <- c(2, -2, 2, 2)
coeffs <- cbind(beta[1], beta[2], beta[3] + 2 * Z[, 1], beta[4] * (1 - 2 * Z[, 2]))
mu <- diag(X[, 1:4] %*% t(coeffs))
y <- mu + 0.5 * rnorm(N, mean = 0, sd = 1)
```

# Prototyping the pliable lasso
## Introduction
Tibshirani and Friedman (2017) propose a generalization of the lasso that allows the model coefficients to vary as a function of a general set of modifying variables, such as gender, age or time. The pliable lasso model has the form

$$
\hat{y} = \beta_0 \mathbf{1} + Z\theta_0 + \sum_{j=1}^p X_j(\beta_j \mathbf{1} + Z\theta_j), \tag{1}
$$

where $X$ is an $N\times p$ matrix of predictors with columns $X_j$, $Z$ is an $N\times K$ matrix of modifying variables, $\beta_j$ is the main effect of predictor $j$, and $\theta_j$ is a $K$-vector of coefficients describing how $Z$ modifies that effect.

The objective function used for the pliable lasso is

$$
J(\beta_0, \theta_0, \beta, \Theta) = \frac{1}{2N}\sum_{i=1}^N (y_i - \hat{y}_i)^2
+ (1-\alpha)\lambda \sum_{j=1}^p \Bigl(\lVert(\beta_j, \theta_j)\rVert_2 + \lVert\theta_j\rVert_2\Bigr)
+ \alpha\lambda \sum_{j,k} \lvert\theta_{j,k}\rvert.
$$

In the above, $\Theta$ is the $p\times K$ matrix whose rows are the $\theta_j$, and $\lambda \ge 0$ and $\alpha \in [0, 1]$ are tuning parameters. When $\Theta = 0$, the penalty reduces to $(1-\alpha)\lambda\lVert\beta\rVert_1$, a lasso-type penalty on $\beta$ alone.
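To keep the notation concrete, here is a plain-R sketch (a hypothetical helper, not part of CVXR or the forthcoming package) that evaluates $J$ at fixed numeric parameter values; the convex optimization itself is done with CVXR below.

```r
## Evaluate the pliable lasso objective at fixed numeric parameter values.
## beta is a p-vector; Theta is a p x K matrix whose rows are the theta_j.
plasso_objective <- function(y, X, Z, beta0, theta0, beta, Theta, lambda, alpha) {
    N <- length(y)
    ## Column j of Z %*% t(Theta) is Z theta_j, so the row sums give the
    ## interaction term sum_j x_ij * (z_i' theta_j).
    xz <- rowSums(X * (Z %*% t(Theta)))
    y_hat <- beta0 + as.vector(X %*% beta) + as.vector(Z %*% theta0) + xz
    loss <- sum((y - y_hat)^2) / (2 * N)
    ## ||(beta_j, theta_j)||_2 = sqrt(beta_j^2 + ||theta_j||^2), summed over j.
    group_pen <- sum(sqrt(beta^2 + rowSums(Theta^2))) + sum(sqrt(rowSums(Theta^2)))
    l1_pen <- sum(abs(Theta))
    loss + (1 - alpha) * lambda * group_pen + alpha * lambda * l1_pen
}
```

With $\Theta = 0$ the penalty collapses to $(1-\alpha)\lambda\lVert\beta\rVert_1$, which confirms that the pliable lasso contains a rescaled lasso as a special case.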
An R package for the pliable lasso is forthcoming from the authors. Nevertheless, the pliable lasso is an excellent example to highlight the prototyping capabilities of CVXR in research. Along the way, we also illustrate some additional atoms that are needed for this example.
## The pliable lasso in CVXR
We will use a simulated example from Section 3 of Tibshirani and Friedman (2017) with $N = 100$, $p = 50$, and $K = 4$. The response is generated (in the simulation code at the top of this post) as

$$
y = \mu(x) + 0.5\,\varepsilon, \qquad \varepsilon \sim N(0, 1),
$$

where

$$
\mu(x) = x_1\beta_1 + x_2\beta_2 + x_3(\beta_3 + 2z_1) + x_4\beta_4(1 - 2z_2),
$$

with $\beta = (2, -2, 2, 2, 0, \ldots, 0)$.
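As an aside on the simulation code above: `diag(X[, 1:4] %*% t(coeffs))` forms an $N \times N$ product only to keep its diagonal. The same row-wise inner products can be computed directly with `rowSums`, which is worth knowing for larger $N$. A quick check of the identity:

```r
## diag(A %*% t(B)) equals rowSums(A * B): both compute the inner product
## of row i of A with row i of B, but rowSums avoids the N x N product.
A <- matrix(1:12, nrow = 4, ncol = 3)
B <- matrix(seq(0.5, 6, by = 0.5), nrow = 4, ncol = 3)
stopifnot(all.equal(diag(A %*% t(B)), rowSums(A * B)))
```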
It seems worthwhile to write a function that will fit the model for us, so that we can customize a few things such as an intercept term, verbosity, etc. The function has the following structure, with comments as placeholders for code we construct below.
```r
plasso_fit <- function(y, X, Z, lambda, alpha = 0.5, intercept = TRUE,
                       ZERO_THRESHOLD = 1e-6, verbose = FALSE) {
    N <- length(y)
    p <- ncol(X)
    K <- ncol(Z)
    beta0 <- 0
    if (intercept) {
        beta0 <- Variable(1) * matrix(1, nrow = N, ncol = 1)
    }
    ## Define_Parameters
    ## Build_Penalty_Terms
    ## Compute_Fitted_Value
    ## Build_Objective
    ## Define_and_Solve_Problem
    ## Return_Values
}
```
```r
## Fit pliable lasso using CVXR.
# pliable <- plasso_fit(y, X, Z, alpha = 0.5, lambda = lambda)
```

### Defining the parameters
The parameters are easy: we just have
```r
beta <- Variable(p)
theta0 <- Variable(K)
theta <- Variable(c(p, K))
theta_transpose <- t(theta)
```

Note that we also define the transpose of $\Theta$, which we use later when computing the fitted values.
### The penalty terms
There are three of them. The first term in the parenthesis, $\sum_{j=1}^p \lVert(\beta_j, \theta_j)\rVert_2$, is a sum of Euclidean norms of the rows of the column-bound matrix $[\beta, \Theta]$. CVXR provides two functions to express this norm:

- `hstack`, to bind the column $\beta$ and the matrix $\Theta$, the equivalent of `cbind` in R, and
- `cvxr_norm`, which accepts a matrix variable and an `axis` denoting the axis along which the norm is to be taken. The penalty requires us to use the row as the axis, so `axis = 1` per the usual R convention.

The second term in the parenthesis, $\sum_{j=1}^p \lVert\theta_j\rVert_2$, is handled the same way, and the last term, $\sum_{j,k}\lvert\theta_{j,k}\rvert$, is the sum of absolute values of $\Theta$.
```r
penalty_term1 <- sum(cvxr_norm(hstack(beta, theta), 2, axis = 1))
penalty_term2 <- sum(cvxr_norm(theta, 2, axis = 1))
penalty_term3 <- sum(cvxr_norm(theta, 1))
```
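For intuition, the first penalty term is just the sum of the Euclidean norms of the rows of $[\beta, \Theta]$. On ordinary numeric matrices (not CVXR expressions) the same quantity can be computed with `cbind` and `apply`; this is what `hstack` and `cvxr_norm(..., axis = 1)` express symbolically:

```r
## Numeric analogue of sum(cvxr_norm(hstack(beta, theta), 2, axis = 1)).
beta_num  <- c(2, -2, 0)                    # p = 3 main effects
Theta_num <- matrix(c(1, 0, 0,
                      0, 3, 0), nrow = 3)   # p x K with K = 2
M <- cbind(beta_num, Theta_num)             # p x (1 + K)
penalty1 <- sum(apply(M, 1, function(r) sqrt(sum(r^2))))
stopifnot(all.equal(penalty1, sqrt(5) + sqrt(13)))
```

Zero rows of $[\beta, \Theta]$ contribute nothing, which is what makes this a group (row-wise) sparsity penalty.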
### The fitted value

Equation 1 above contains the sum $\sum_{j=1}^p X_j(Z\theta_j)$ in addition to the usual linear terms. We compute each summand with `lapply` and then combine them with `Reduce` and `+` to obtain the `XZ_term` below.
```r
xz_theta <- lapply(seq_len(p),
                   function(j) (matrix(X[, j], nrow = N, ncol = K) * Z) %*% theta_transpose[, j])
XZ_term <- Reduce(f = '+', x = xz_theta)
y_hat <- beta0 + X %*% beta + Z %*% theta0 + XZ_term
```
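The `lapply`/`Reduce` construction builds $\sum_j X_j(Z\theta_j)$ one predictor at a time, each summand a valid CVXR expression. On plain numeric inputs the loop agrees with the vectorized closed form `rowSums(X * (Z %*% t(Theta)))`, as this small check (with hypothetical toy matrices) shows:

```r
set.seed(1)
N <- 5; p <- 3; K <- 2
Xn <- matrix(rnorm(N * p), N, p)
Zn <- matrix(rbinom(N * K, 1, 0.5), N, K)
Tn <- matrix(rnorm(p * K), p, K)   # rows are the theta_j
## Term by term, mirroring the CVXR code above ...
terms <- lapply(seq_len(p), function(j) Xn[, j] * as.vector(Zn %*% Tn[j, ]))
xz_loop <- Reduce(`+`, terms)
## ... and all at once.
xz_direct <- rowSums(Xn * (Zn %*% t(Tn)))
stopifnot(all.equal(xz_loop, xz_direct))
```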
### The objective

The objective is now straightforward.
```r
objective <- sum_squares(y - y_hat) / (2 * N) +
    (1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
    alpha * lambda * penalty_term3
```

### The problem and its solution
```r
prob <- Problem(Minimize(objective))
opt_val <- psolve(prob, solver = "CLARABEL", verbose = verbose)
check_solver_status(prob)
beta_hat <- value(beta)
```

### The return values
We create a list with the values of interest to us. Since sparsity is desired, we zero out estimates whose absolute value falls below `ZERO_THRESHOLD`.
```r
theta0_hat <- value(theta0)
theta_hat <- value(theta)
## Zero out small values before returning.
beta_hat[abs(beta_hat) < ZERO_THRESHOLD] <- 0.0
theta0_hat[abs(theta0_hat) < ZERO_THRESHOLD] <- 0.0
theta_hat[abs(theta_hat) < ZERO_THRESHOLD] <- 0.0
list(beta0_hat = if (intercept) value(beta0)[1] else 0.0,
     beta_hat = beta_hat,
     theta0_hat = theta0_hat,
     theta_hat = theta_hat,
     criterion = opt_val)
```

### The full function
We now put it all together.
```r
plasso_fit <- function(y, X, Z, lambda, alpha = 0.5, intercept = TRUE,
                       ZERO_THRESHOLD = 1e-6, verbose = FALSE) {
    N <- length(y)
    p <- ncol(X)
    K <- ncol(Z)
    beta0 <- 0
    if (intercept) {
        beta0 <- Variable(1) * matrix(1, nrow = N, ncol = 1)
    }
    beta <- Variable(p)
    theta0 <- Variable(K)
    theta <- Variable(c(p, K))
    theta_transpose <- t(theta)
    penalty_term1 <- sum(cvxr_norm(hstack(beta, theta), 2, axis = 1))
    penalty_term2 <- sum(cvxr_norm(theta, 2, axis = 1))
    penalty_term3 <- sum(cvxr_norm(theta, 1))
    xz_theta <- lapply(seq_len(p),
                       function(j) (matrix(X[, j], nrow = N, ncol = K) * Z) %*% theta_transpose[, j])
    XZ_term <- Reduce(f = '+', x = xz_theta)
    y_hat <- beta0 + X %*% beta + Z %*% theta0 + XZ_term
    objective <- sum_squares(y - y_hat) / (2 * N) +
        (1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
        alpha * lambda * penalty_term3
    prob <- Problem(Minimize(objective))
    opt_val <- psolve(prob, solver = "CLARABEL", verbose = verbose)
    check_solver_status(prob)
    beta_hat <- value(beta)
    theta0_hat <- value(theta0)
    theta_hat <- value(theta)
    ## Zero out small values before returning.
    beta_hat[abs(beta_hat) < ZERO_THRESHOLD] <- 0.0
    theta0_hat[abs(theta0_hat) < ZERO_THRESHOLD] <- 0.0
    theta_hat[abs(theta_hat) < ZERO_THRESHOLD] <- 0.0
    list(beta0_hat = if (intercept) value(beta0)[1] else 0.0,
         beta_hat = beta_hat,
         theta0_hat = theta0_hat,
         theta_hat = theta_hat,
         criterion = opt_val)
}
```

## The Results
Using $\lambda = 0.6$ and $\alpha = 0.5$, we fit the model without an intercept.

```r
result <- plasso_fit(y, X, Z, lambda = 0.6, alpha = 0.5, intercept = FALSE)
```

```
────────────────────────────────── CVXR v1.8.1 ─────────────────────────────────
ℹ Problem: 3 variables, 0 constraints (DCP)
ℹ Compilation: "CLARABEL" via CVXR::Dcp2Cone -> CVXR::CvxAttr2Constr -> CVXR::ConeMatrixStuffing -> CVXR::Clarabel_Solver
ℹ Compile time: 0.286s
─────────────────────────────── Numerical solver ───────────────────────────────
──────────────────────────────────── Summary ───────────────────────────────────
✔ Status: optimal
✔ Optimal value: 4.27945
ℹ Compile time: 0.286s
ℹ Solver time: 0.023s
```
We can print the various estimates.
```r
cat(sprintf("Objective value: %f\n", result$criterion))
```

```
Objective value: 4.279446
```
We only print the nonzero estimates $\hat{\beta}_j$.
```r
index <- which(result$beta_hat != 0)
est.table <- data.frame(matrix(result$beta_hat[index], nrow = 1))
names(est.table) <- paste0("\\(\\beta_{", index, "}\\)")
knitr::kable(est.table, format = "html", escape = FALSE, digits = 3) |>
    kable_styling("striped")
```

```
| 1.783 | -1.373 | 2.736 | 0.021 | -0.141 | -0.093 | 0.066 |
```
For this value of $\lambda$, a handful of coefficients are estimated as nonzero, and the largest ones correspond to the true signal variables. The values for $\hat{\theta}_0$ are shown below.
```r
est.table <- data.frame(matrix(result$theta0_hat, nrow = 1))
names(est.table) <- paste0("\\(\\theta_{0,", 1:K, "}\\)")
knitr::kable(est.table, format = "html", escape = FALSE, digits = 3) |>
    kable_styling("striped")
```

| \(\theta_{0,1}\) | \(\theta_{0,2}\) | \(\theta_{0,3}\) | \(\theta_{0,4}\) |
|---:|---:|---:|---:|
| -0.153 | 0.281 | -0.65 | 0.102 |
And just the first five rows of $\hat{\Theta}$, which is mostly sparse.
```r
est.table <- data.frame(result$theta_hat[1:5, ])
names(est.table) <- paste0("\\(\\theta_{,", 1:K, "}\\)")
knitr::kable(est.table, format = "html", escape = FALSE, digits = 3) |>
    kable_styling("striped")
```

| \(\theta_{,1}\) | \(\theta_{,2}\) | \(\theta_{,3}\) | \(\theta_{,4}\) |
|---:|---:|---:|---:|
| 0 | 0.000 | 0 | 0 |
| 0 | 0.000 | 0 | 0 |
| 0 | 0.000 | 0 | 0 |
| 0 | -0.093 | 0 | 0 |
| 0 | 0.000 | 0 | 0 |
## Final comments
Typically, one would run the fit for a decreasing sequence of values of $\lambda$ and select a model using a criterion such as cross-validation.
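A minimal sketch of such a path, assuming the data and the `plasso_fit` function above are in scope (the grid below is illustrative, not from the paper):

```r
## Fit over a decreasing grid of lambda values and track sparsity.
lambdas <- exp(seq(log(2), log(0.05), length.out = 10))
fits <- lapply(lambdas, function(lam)
    plasso_fit(y, X, Z, lambda = lam, alpha = 0.5, intercept = FALSE))
nonzero <- vapply(fits, function(fit) sum(fit$beta_hat != 0), integer(1))
data.frame(lambda = round(lambdas, 3), nonzero_beta = nonzero)
```

One would then select $\lambda$ by a criterion such as cross-validation.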
A logistic regression using a pliable lasso model can be prototyped similarly.
## Session Info
```
R version 4.5.2 (2025-10-31)
Platform: aarch64-apple-darwin20
Running under: macOS Tahoe 26.3.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base

other attached packages:
[1] kableExtra_1.4.0 CVXR_1.8.1

loaded via a namespace (and not attached):
 [1] gmp_0.7-5.1        clarabel_0.11.2    xml2_1.5.2         slam_0.1-55
 [5] stringi_1.8.7      lattice_0.22-9     digest_0.6.39      magrittr_2.0.4
 [9] evaluate_1.0.5     grid_4.5.2         RColorBrewer_1.1-3 fastmap_1.2.0
[13] rprojroot_2.1.1    jsonlite_2.0.0     Matrix_1.7-4       ECOSolveR_0.6.1
[17] backports_1.5.0    scs_3.2.7          Rmosek_11.1.1      viridisLite_0.4.3
[21] scales_1.4.0       codetools_0.2-20   textshaping_1.0.4  cli_3.6.5
[25] rlang_1.1.7        Rglpk_0.6-5.1      yaml_2.3.12        otel_0.2.0
[29] tools_4.5.2        osqp_1.0.0         Rcplex_0.3-8       checkmate_2.3.4
[33] here_1.0.2         gurobi_13.0-1      vctrs_0.7.1        R6_2.6.1
[37] lifecycle_1.0.5    stringr_1.6.0      htmlwidgets_1.6.4  cccp_0.3-3
[41] glue_1.8.0         Rcpp_1.1.1         systemfonts_1.3.1  xfun_0.56
[45] rstudioapi_0.18.0  knitr_1.51         dichromat_2.0-0.1  highs_1.12.0-3
[49] farver_2.1.2       htmltools_0.5.9    rmarkdown_2.30     svglite_2.2.2
[53] piqp_0.6.2         compiler_4.5.2     S7_0.2.1
```