Prototyping the pliable lasso

Author

Anqi Fu and Balasubramanian Narasimhan

Introduction

Tibshirani and Friedman (2017) propose a generalization of the lasso that allows the model coefficients to vary as a function of a general set of modifying variables, such as gender, age, or time. The pliable lasso model has the form

$$\hat{y} = \beta_0 \mathbf{1} + Z\theta_0 + \sum_{j=1}^p \left( X_j \beta_j + W_j \theta_j \right)$$

where $\hat{y}$ is the predicted $N \times 1$ vector, $\beta_0$ is a scalar, $\theta_0$ is a $K$-vector, $X$ and $Z$ are $N \times p$ and $N \times K$ matrices containing values of the predictor and modifying variables respectively, and $W_j = X_j \circ Z$ denotes the elementwise multiplication of each column of $Z$ by column $X_j$ of $X$.
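Concretely, each $W_j$ scales every column of $Z$ by column $j$ of $X$. A small base-R illustration (toy data and hypothetical names, independent of the model code below):

```r
## W_j = X_j o Z: column j of X multiplies each column of Z elementwise.
set.seed(1)
n_obs <- 5; n_mod <- 2
x_j <- rnorm(n_obs)                                  ## stands in for column j of X
z_mod <- matrix(rbinom(n_obs * n_mod, size = 1, prob = 0.5), n_obs, n_mod)
W_j <- x_j * z_mod   ## R recycles x_j down each column of z_mod
```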

The objective function for the pliable lasso is

$$J(\beta_0, \theta_0, \beta, \Theta) = \frac{1}{2N}\sum_{i=1}^N (y_i - \hat{y}_i)^2 + (1 - \alpha)\lambda \sum_{j=1}^p \left( \|(\beta_j, \theta_j)\|_2 + \|\theta_j\|_2 \right) + \alpha\lambda \sum_{j,k} |\theta_{j,k}|.$$

In the above, $\Theta$ is a $p \times K$ matrix of parameters with $j$-th row $\theta_j$ and individual entries $\theta_{j,k}$, and $\lambda$ is a tuning parameter. As $\alpha \to 1$ (but $\alpha < 1$), the solution approaches the lasso solution. The default value used is $\alpha = 0.5$.

An R package for the pliable lasso is forthcoming from the authors. Nevertheless, the pliable lasso is an excellent example to highlight the prototyping capabilities of CVXR in research. Along the way, we also illustrate some additional atoms that are actually needed in this example.

The pliable lasso in CVXR

We will use a simulated example from Section 3 of Tibshirani and Friedman (2017) with $n = 100$, $p = 50$, and $K = 4$. The response is generated as

$$y = \mu(x) + 0.5\epsilon, \qquad \epsilon \sim N(0, 1),$$
$$\mu(x) = x_1\beta_1 + x_2\beta_2 + x_3(\beta_3 e + 2 z_1) + x_4 \beta_4 (e - 2 z_2), \qquad \beta = (2, -2, 2, 2, 0, 0, \ldots)$$

where $e = (1, 1, \ldots, 1)^T$.

## Simulation data.
set.seed(123)
N <- 100
K <- 4
p <- 50
X <- matrix(rnorm(n = N * p, mean = 0, sd = 1), nrow = N, ncol = p)
Z <- matrix(rbinom(n = N * K, size = 1, prob = 0.5), nrow = N, ncol = K)

## Response model.
beta <- rep(x = 0, times = p)
beta[1:4] <- c(2, -2, 2, 2)
coeffs <- cbind(beta[1], beta[2], beta[3] + 2 * Z[, 1], beta[4] * (1 - 2 * Z[, 2]))
mu <- rowSums(X[, 1:4] * coeffs)  ## same as diag(X[, 1:4] %*% t(coeffs)), without the N x N product
y <- mu + 0.5 * rnorm(N, mean = 0, sd = 1)
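As an aside, the diagonal of a product $A B^T$ for conformable $A$ and $B$ is just the row-wise sum of $A \circ B$, which is why rowSums computes the per-observation mean without materializing an $N \times N$ matrix. A toy check of that identity:

```r
## diag(A %*% t(B)) equals rowSums(A * B): each diagonal entry is the
## inner product of matching rows of A and B.
set.seed(42)
A <- matrix(rnorm(6 * 4), 6, 4)
B <- matrix(rnorm(6 * 4), 6, 4)
via_diag <- diag(A %*% t(B))
via_rows <- rowSums(A * B)
```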

It seems worthwhile to write a function that fits the model for us so that we can customize a few things, such as an intercept term and verbosity. The function has the following structure, with comments as placeholders for code we shall construct later.

plasso_fit <- function(y, X, Z, lambda, alpha = 0.5, intercept = TRUE,
                       ZERO_THRESHOLD = 1e-6, verbose = FALSE) {
    N <- length(y)
    p <- ncol(X)
    K <- ncol(Z)

    beta0 <- 0
    if (intercept) {
        beta0 <- Variable(1) * matrix(1, nrow = N, ncol = 1)
    }
    ## Define_Parameters
    ## Build_Penalty_Terms
    ## Compute_Fitted_Value
    ## Build_Objective
    ## Define_and_Solve_Problem
    ## Return_Values
}

## Fit pliable lasso using CVXR.
## pliable <- plasso_fit(y, X, Z, alpha = 0.5, lambda = lambda)

Defining the parameters

The parameters are easy: we just have β, θ0 and Θ.

beta <- Variable(p)
theta0 <- Variable(K)
theta <- Variable(c(p, K)) ; theta_transpose <- t(theta)

Note that we also define the transpose of Θ for use later.

The penalty terms

There are three of them. The first term in the parenthesis, $\sum_{j=1}^p \|(\beta_j, \theta_j)\|_2$, involves components of $\beta$ and rows of $\Theta$. CVXR provides two functions to express this norm:

  • hstack to bind $\beta$ and the matrix $\Theta$ column-wise, the equivalent of cbind in R,
  • cvxr_norm, which accepts a matrix variable and an axis denoting the axis along which the norm is to be taken. The penalty requires the norm along rows, so axis = 1 per the usual R convention.

The second term in the parenthesis, $\sum_j \|\theta_j\|_2$, is also a norm along rows, as the $\theta_j$ are rows of $\Theta$. The last term is simply an $\ell_1$-norm of the entries of $\Theta$.

penalty_term1 <- sum(cvxr_norm(hstack(beta, theta), 2, axis = 1))
penalty_term2 <- sum(cvxr_norm(theta, 2, axis = 1))
penalty_term3 <- sum(cvxr_norm(theta, 1))
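The axis semantics mirror base R's apply, where MARGIN = 1 means row-wise. On an ordinary numeric matrix the row-wise 2-norm is computed as below, which is the numeric analogue of what cvxr_norm(theta, 2, axis = 1) builds for a CVXR expression:

```r
## Row-wise 2-norms on a plain matrix, mirroring cvxr_norm(theta, 2, axis = 1).
M <- rbind(c(3, 4), c(5, 12), c(8, 15))
row_norms <- apply(M, 1, function(r) sqrt(sum(r^2)))
row_norms   ## 5 13 17
```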

The fitted value

Equation 1 above for $\hat{y}$ contains a sum: $\sum_{j=1}^p (X_j\beta_j + W_j\theta_j)$. This requires multiplying each column of $X$ elementwise with $Z$. That is a natural candidate for a map-reduce combination: map the column multiplication function over the columns of $X$ and reduce using + to obtain the XZ_term below.

xz_theta <- lapply(seq_len(p),
                   function(j) (matrix(X[, j], nrow = N, ncol = K) * Z) %*% theta_transpose[, j])
XZ_term <- Reduce(f = '+', x = xz_theta)
y_hat <- beta0 + X %*% beta + Z %*% theta0 + XZ_term
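For a numeric $\Theta$, the same map-reduce reproduces $\sum_{j=1}^p W_j \theta_j$ exactly. A base-R check with plain matrices standing in for the CVXR expressions (toy sizes, hypothetical names):

```r
## Map over columns of Xm, reduce with `+`, as in the CVXR version above,
## but with ordinary numeric matrices.
set.seed(7)
n_obs <- 6; p_var <- 3; n_mod <- 2
Xm <- matrix(rnorm(n_obs * p_var), n_obs, p_var)
Zm <- matrix(rbinom(n_obs * n_mod, size = 1, prob = 0.5), n_obs, n_mod)
Th <- matrix(rnorm(p_var * n_mod), p_var, n_mod)

xz_parts <- lapply(seq_len(p_var), function(j) (Xm[, j] * Zm) %*% Th[j, ])
xz_sum <- Reduce(`+`, xz_parts)
```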

The objective

The objective is now straightforward.

objective <- sum_squares(y - y_hat) / (2 * N) +
    (1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
    alpha * lambda * penalty_term3

The problem and its solution

prob <- Problem(Minimize(objective))
opt_val <- psolve(prob, solver = "CLARABEL", verbose = verbose)
check_solver_status(prob)
beta_hat <- value(beta)

The return values

We create a list with values of interest to us. However, since sparsity is desired, we set estimates whose absolute value falls below ZERO_THRESHOLD to zero.

theta0_hat <- value(theta0)
theta_hat <- value(theta)

## Zero out stuff before returning
beta_hat[abs(beta_hat) < ZERO_THRESHOLD] <- 0.0
theta0_hat[abs(theta0_hat) < ZERO_THRESHOLD] <- 0.0
theta_hat[abs(theta_hat) < ZERO_THRESHOLD] <- 0.0
list(beta0_hat = if (intercept) value(beta0)[1] else 0.0,
     beta_hat = beta_hat,
     theta0_hat = theta0_hat,
     theta_hat = theta_hat,
     criterion = opt_val)

The full function

We now put it all together.

plasso_fit <- function(y, X, Z, lambda, alpha = 0.5, intercept = TRUE,
                       ZERO_THRESHOLD = 1e-6, verbose = FALSE) {
    N <- length(y)
    p <- ncol(X)
    K <- ncol(Z)

    beta0 <- 0
    if (intercept) {
        beta0 <- Variable(1) * matrix(1, nrow = N, ncol = 1)
    }
    beta <- Variable(p)
    theta0 <- Variable(K)
    theta <- Variable(c(p, K)) ; theta_transpose <- t(theta)
    penalty_term1 <- sum(cvxr_norm(hstack(beta, theta), 2, axis = 1))
    penalty_term2 <- sum(cvxr_norm(theta, 2, axis = 1))
    penalty_term3 <- sum(cvxr_norm(theta, 1))
    xz_theta <- lapply(seq_len(p),
                       function(j) (matrix(X[, j], nrow = N, ncol = K) * Z) %*% theta_transpose[, j])
    XZ_term <- Reduce(f = '+', x = xz_theta)
    y_hat <- beta0 + X %*% beta + Z %*% theta0 + XZ_term
    objective <- sum_squares(y - y_hat) / (2 * N) +
        (1 - alpha) * lambda * (penalty_term1 + penalty_term2) +
        alpha * lambda * penalty_term3
    prob <- Problem(Minimize(objective))
    opt_val <- psolve(prob, solver = "CLARABEL", verbose = verbose)
    check_solver_status(prob)
    beta_hat <- value(beta)
    theta0_hat <- value(theta0)
    theta_hat <- value(theta)
    
    ## Zero out stuff before returning
    beta_hat[abs(beta_hat) < ZERO_THRESHOLD] <- 0.0
    theta0_hat[abs(theta0_hat) < ZERO_THRESHOLD] <- 0.0
    theta_hat[abs(theta_hat) < ZERO_THRESHOLD] <- 0.0
    list(beta0_hat = if (intercept) value(beta0)[1] else 0.0,
         beta_hat = beta_hat,
         theta0_hat = theta0_hat,
         theta_hat = theta_hat,
         criterion = opt_val)
}
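To evaluate a fit on new data we need $\hat{y}$ from Equation 1. A small helper that rebuilds it in base R from the returned list (hypothetical, i.e. not part of the forthcoming package) could look like:

```r
## Hypothetical predict helper: reconstructs Equation 1 from plasso_fit output.
plasso_predict <- function(fit, X, Z) {
    p <- ncol(X)
    ## W_j theta_j terms, mapped over columns of X and summed.
    xz <- lapply(seq_len(p), function(j) (X[, j] * Z) %*% fit$theta_hat[j, ])
    drop(fit$beta0_hat + X %*% fit$beta_hat + Z %*% fit$theta0_hat +
             Reduce(`+`, xz))
}
```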

The results

Using $\lambda = 0.6$, we fit the pliable lasso without an intercept.

result <- plasso_fit(y, X, Z, lambda = 0.6, alpha = 0.5, intercept = FALSE)
────────────────────────────────── CVXR v1.8.1 ─────────────────────────────────
ℹ Problem: 3 variables, 0 constraints (DCP)
ℹ Compilation: "CLARABEL" via CVXR::Dcp2Cone -> CVXR::CvxAttr2Constr -> CVXR::ConeMatrixStuffing -> CVXR::Clarabel_Solver
ℹ Compile time: 0.286s
─────────────────────────────── Numerical solver ───────────────────────────────
──────────────────────────────────── Summary ───────────────────────────────────
✔ Status: optimal
✔ Optimal value: 4.27945
ℹ Compile time: 0.286s
ℹ Solver time: 0.023s

We can print the various estimates.

cat(sprintf("Objective value: %f\n", result$criterion))
Objective value: 4.279446

We print only the nonzero $\beta$ values.

index <- which(result$beta_hat != 0)
est.table <- data.frame(matrix(result$beta_hat[index], nrow = 1))
names(est.table) <- paste0("\\(\\beta_{", index, "}\\)")
knitr::kable(est.table, format = "html", escape = FALSE, digits = 3) |>
    kable_styling("striped")
β1 β2 β3 β4 β20 β34 β39
1.783 -1.373 2.736 0.021 -0.141 -0.093 0.066

For this value of $\lambda$, the nonzero $(\beta_1, \beta_2, \beta_3, \beta_4)$ are picked up along with a few others ($\beta_{20}$, $\beta_{34}$, $\beta_{39}$).

The values for $\theta_0$:

est.table <- data.frame(matrix(result$theta0_hat, nrow = 1))
names(est.table) <- paste0("\\(\\theta_{0,", 1:K, "}\\)")
knitr::kable(est.table, format = "html", escape = FALSE, digits = 3) |>
    kable_styling("striped")
θ0,1 θ0,2 θ0,3 θ0,4
-0.153 0.281 -0.65 0.102

And just the first five rows of Θ, which happen to contain all the nonzero values for this result.

est.table <- data.frame(result$theta_hat[1:5, ])
names(est.table) <- paste0("\\(\\theta_{,", 1:K, "}\\)")
knitr::kable(est.table, format = "html", escape = FALSE, digits = 3) |>
    kable_styling("striped")
θ,1 θ,2 θ,3 θ,4
0 0.000 0 0
0 0.000 0 0
0 0.000 0 0
0 -0.093 0 0
0 0.000 0 0

Final comments

Typically, one would run the fit for a grid of $\lambda$ values, choose one by cross-validation, and then assess the prediction error on a held-out test set. Here even a single fit takes a while, but techniques discussed in other articles here can be used to speed up the computations.
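A minimal sketch of that workflow, with the fitter and predictor passed in as functions so that plasso_fit (plus a suitable predict helper) or any stand-in can be plugged in; the helper name and signatures here are hypothetical:

```r
## Hypothetical K-fold CV skeleton over a lambda grid.
## fit_fun(y, X, Z, lambda) returns a fit; predict_fun(fit, X, Z) predicts.
cv_lambda <- function(y, X, Z, lambdas, fit_fun, predict_fun, nfolds = 5) {
    folds <- sample(rep(seq_len(nfolds), length.out = length(y)))
    cv_mse <- sapply(lambdas, function(lam) {
        mean(sapply(seq_len(nfolds), function(f) {
            tr <- folds != f   ## train on all folds but f
            fit <- fit_fun(y[tr], X[tr, , drop = FALSE], Z[tr, , drop = FALSE], lam)
            mean((y[!tr] - predict_fun(fit, X[!tr, , drop = FALSE],
                                       Z[!tr, , drop = FALSE]))^2)
        }))
    })
    lambdas[which.min(cv_mse)]   ## lambda with the smallest CV error
}
```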

A logistic regression using a pliable lasso model can be prototyped similarly.

Session Info

R version 4.5.2 (2025-10-31)
Platform: aarch64-apple-darwin20
Running under: macOS Tahoe 26.3.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] kableExtra_1.4.0 CVXR_1.8.1      

loaded via a namespace (and not attached):
 [1] gmp_0.7-5.1        clarabel_0.11.2    xml2_1.5.2         slam_0.1-55       
 [5] stringi_1.8.7      lattice_0.22-9     digest_0.6.39      magrittr_2.0.4    
 [9] evaluate_1.0.5     grid_4.5.2         RColorBrewer_1.1-3 fastmap_1.2.0     
[13] rprojroot_2.1.1    jsonlite_2.0.0     Matrix_1.7-4       ECOSolveR_0.6.1   
[17] backports_1.5.0    scs_3.2.7          Rmosek_11.1.1      viridisLite_0.4.3 
[21] scales_1.4.0       codetools_0.2-20   textshaping_1.0.4  cli_3.6.5         
[25] rlang_1.1.7        Rglpk_0.6-5.1      yaml_2.3.12        otel_0.2.0        
[29] tools_4.5.2        osqp_1.0.0         Rcplex_0.3-8       checkmate_2.3.4   
[33] here_1.0.2         gurobi_13.0-1      vctrs_0.7.1        R6_2.6.1          
[37] lifecycle_1.0.5    stringr_1.6.0      htmlwidgets_1.6.4  cccp_0.3-3        
[41] glue_1.8.0         Rcpp_1.1.1         systemfonts_1.3.1  xfun_0.56         
[45] rstudioapi_0.18.0  knitr_1.51         dichromat_2.0-0.1  highs_1.12.0-3    
[49] farver_2.1.2       htmltools_0.5.9    rmarkdown_2.30     svglite_2.2.2     
[53] piqp_0.6.2         compiler_4.5.2     S7_0.2.1          

References

Tibshirani, Robert J., and Jerome H. Friedman. 2017. “A Pliable Lasso.” arXiv Preprint Arxiv:1712.00484. https://arxiv.org/abs/1712.00484.