Huber Regression
Introduction
Huber regression (Huber 1964) is a regression technique that is robust to outliers. Instead of the traditional least-squares loss, it uses the Huber loss; that is, we solve
\[\begin{array}{ll} \underset{\beta}{\mbox{minimize}} & \sum_{i=1}^m \phi(y_i - x_i^T\beta) \end{array}\]

for the variable \(\beta \in {\mathbf R}^n\), where \(x_i \in {\mathbf R}^n\) and \(y_i \in {\mathbf R}\) are the features and response of the \(i\)th of \(m\) observations, and the loss \(\phi\) is the Huber function with threshold \(M > 0\),

\[ \phi(u) = \begin{cases} u^2 & \mbox{if } |u| \leq M \\ 2M|u| - M^2 & \mbox{if } |u| > M. \end{cases} \]
This function is identical to the least-squares penalty for small residuals (\(|u| \leq M\)), but for large residuals (\(|u| > M\)) its penalty is smaller and grows linearly rather than quadratically. It is therefore more forgiving of outliers.
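As a quick check of the definition, the sketch below codes \(\phi\) directly in base R and compares it with the squared loss at a few residuals. The helper name `huber_loss` is a hypothetical name of our own; it is not the CVXR atom `huber()`, which operates on CVXR expressions.

```r
## Hypothetical base-R helper (not CVXR's huber()): the Huber function phi
## with threshold M, applied elementwise to a numeric vector of residuals
huber_loss <- function(u, M = 1) {
  ifelse(abs(u) <= M, u^2, 2 * M * abs(u) - M^2)
}

u <- c(-3, -1, 0.5, 1, 3)
## Huber penalty vs. squared (least-squares) penalty, row by row
rbind(huber = huber_loss(u, M = 1), squared = u^2)
## values: huber 5 1 0.25 1 5; squared 9 1 0.25 1 9
```

The two penalties agree on \([-M, M]\); outside it the Huber penalty grows linearly. Since \(u^2 - (2M|u| - M^2) = (|u| - M)^2 \geq 0\), the Huber penalty never exceeds the squared penalty, and the two pieces meet at \(|u| = M\).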
Example
We generate some problem data.
library(CVXR)
library(ggplot2)
n <- 1 ## Number of features
m <- 450 ## Number of observations
M <- 1 ## Huber threshold
p <- 0.1 ## Fraction of responses with sign flipped
## Generate problem data
set.seed(1289)
beta_true <- 5 * matrix(stats::rnorm(n), nrow = n)
X <- matrix(stats::rnorm(m * n), nrow = m, ncol = n)
y_true <- X %*% beta_true
eps <- matrix(stats::rnorm(m), nrow = m)
We randomly flip the sign of a fraction \(p\) of the responses to create outliers, which lets us illustrate the robustness of the Huber loss.
## Each entry is -1 (sign flipped) with probability p and +1 otherwise
factor <- 2*stats::rbinom(m, size = 1, prob = 1-p) - 1
y <- factor * y_true + eps
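To see why the flips hurt least squares, note that when each sign is flipped independently of \(x_i\) and of the noise, \(E[y_i x_i] = (1-2p)\,\beta_{\rm true}\,E[x_i^2]\), so the OLS estimate is pulled toward \((1-2p)\beta_{\rm true}\), i.e. \(0.8\,\beta_{\rm true}\) for \(p = 0.1\). The base-R sketch below checks this by simulation with one feature, no intercept, and a large sample. The sample size, seed, and true slope are arbitrary illustrative choices, and the `_sim` names are hypothetical, chosen so the example's own variables are not overwritten.

```r
## Simulation check with illustrative values (separate names so the example's
## own variables are not overwritten): the OLS slope tends to (1 - 2p) * beta
set.seed(1)
m_sim <- 1e5
p_sim <- 0.1
beta_sim <- 5                                          # arbitrary true slope
x_sim <- rnorm(m_sim)
flip_sim <- 2 * rbinom(m_sim, size = 1, prob = 1 - p_sim) - 1  # -1 w.p. p_sim
y_sim <- flip_sim * beta_sim * x_sim + rnorm(m_sim)
b_ols_sim <- sum(x_sim * y_sim) / sum(x_sim^2)         # no-intercept OLS slope
## fraction of flipped signs (near p_sim) and slope ratio (near 1 - 2 * p_sim)
c(frac_flipped = mean(flip_sim == -1), ratio = b_ols_sim / beta_sim)
```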
We solve this problem using both ordinary least squares (OLS) and Huber regression for comparison.
beta <- Variable(n)
## Relative error of an estimate of beta with respect to beta_true
rel_err <- norm(beta - beta_true, "F") / norm(beta_true, "F")
## OLS
obj <- sum((y - X %*% beta)^2)
prob <- Problem(Minimize(obj))
result <- solve(prob)
beta_ols <- result$getValue(beta)
err_ols <- result$getValue(rel_err)
## Solve Huber regression problem
obj <- sum(CVXR::huber(y - X %*% beta, M))
prob <- Problem(Minimize(obj))
result <- solve(prob)
beta_hub <- result$getValue(beta)
err_hub <- result$getValue(rel_err)
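CVXR passes the Huber problem to a convex solver. As an independent cross-check outside CVXR, the same minimizer can be approximated with iteratively reweighted least squares (IRLS), a standard technique for Huber-type losses that we swap in here; it is not what CVXR does internally. Each step solves a weighted least-squares problem with weight 1 for residuals within \(M\) and \(M/|r_i|\) otherwise, which equals \(\phi'(r_i)/(2r_i)\). The helper `huber_irls`, the `_chk` names, and the simulated data (a fixed true coefficient of 5 and a larger sample than the example's, so the comparison is stable) are hypothetical choices for this sketch.

```r
## Hypothetical helper: Huber regression (no intercept) via IRLS, base R only.
## This is a swapped-in technique for cross-checking, not CVXR's algorithm.
huber_irls <- function(X, y, M = 1, max_iter = 200, tol = 1e-10) {
  beta <- qr.solve(X, y)                       # start from the OLS fit
  for (k in seq_len(max_iter)) {
    r <- as.vector(y - X %*% beta)             # current residuals
    w <- ifelse(abs(r) <= M, 1, M / abs(r))    # weights phi'(r) / (2r)
    sw <- sqrt(w)
    beta_new <- qr.solve(X * sw, y * sw)       # weighted least-squares step
    converged <- max(abs(beta_new - beta)) < tol
    beta <- beta_new
    if (converged) break
  }
  beta
}

## Illustrative data in the style of the example (separate names so the
## example's own X, y and beta_true are not overwritten)
set.seed(42)
m_chk <- 5000
beta_true_chk <- matrix(5, nrow = 1)
X_chk <- matrix(rnorm(m_chk), nrow = m_chk, ncol = 1)
flip_chk <- 2 * rbinom(m_chk, size = 1, prob = 0.9) - 1   # about 10% flips
y_chk <- flip_chk * (X_chk %*% beta_true_chk) + rnorm(m_chk)

rel_err_chk <- function(b) norm(b - beta_true_chk, "F") / norm(beta_true_chk, "F")
beta_ols_chk <- qr.solve(X_chk, y_chk)
beta_hub_chk <- huber_irls(X_chk, y_chk, M = 1)
c(ols = rel_err_chk(beta_ols_chk), huber = rel_err_chk(beta_hub_chk))
```

At convergence the residuals satisfy the Huber stationarity condition \(\sum_i \phi'(r_i)\,x_i = 0\), which is a convenient way to confirm the result.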
Finally, as a benchmark, we solve the OLS problem assuming the flipped signs are known. We call this the prescient estimate.
## Solve ordinary least squares assuming sign flips known
obj <- sum((y - factor*(X %*% beta))^2)
prob <- Problem(Minimize(obj))
result <- solve(prob)
beta_prs <- result$getValue(beta)
err_prs <- result$getValue(rel_err)
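Because each entry of `factor` is \(\pm 1\), the prescient problem is ordinary least squares on sign-corrected responses. Writing \(f_i\) for the \(i\)th entry of `factor`, we have \(f_i^2 = 1\), so

\[ \sum_{i=1}^m \left(y_i - f_i x_i^T\beta\right)^2 = \sum_{i=1}^m f_i^2 \left(f_i y_i - x_i^T\beta\right)^2 = \sum_{i=1}^m \left(f_i y_i - x_i^T\beta\right)^2. \]

Since \(y_i = f_i x_i^T\beta_{\rm true} + \epsilon_i\), the corrected response is \(f_i y_i = x_i^T\beta_{\rm true} + f_i\epsilon_i\), and \(f_i\epsilon_i\) is again standard normal because the Gaussian noise is symmetric and independent of the flips. The prescient estimate therefore behaves like OLS on uncorrupted data, which is why it serves as the benchmark.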
We can now plot the fit against the measured responses.
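## Measured responses, colored by flip indicator (-1 = flipped)
d1 <- data.frame(X = X, y = y, sign = as.factor(factor))
## Fitted values of the three estimates, stacked for plotting
d2 <- data.frame(X = rbind(X, X, X),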
yHat = rbind(X %*% beta_ols,
X %*% beta_hub,
X %*% beta_prs),
Estimate = c(rep("OLS", m),
rep("Huber", m),
rep("Prescient", m)))
ggplot() +
geom_point(data = d1, mapping = aes(x = X, y = y, color = sign)) +
geom_line(data = d2, mapping = aes(x = X, y = yHat, color = Estimate))
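The plot shows that the Huber fit lies closer to the prescient fit than the OLS fit does: the sign-flipped responses tend to pull the OLS slope toward zero, while the Huber loss limits their influence. The relative errors err_ols, err_hub, and err_prs computed above quantify the same comparison.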
Session Info
sessionInfo()
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-apple-darwin20
## Running under: macOS Sequoia 15.1
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: America/Los_Angeles
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices datasets utils methods base
##
## other attached packages:
## [1] ggplot2_3.5.1 CVXR_1.0-15 testthat_3.2.1.1 here_1.0.1
##
## loaded via a namespace (and not attached):
## [1] gmp_0.7-5 clarabel_0.9.0.1 sass_0.4.9 utf8_1.2.4
## [5] generics_0.1.3 slam_0.1-54 blogdown_1.19 lattice_0.22-6
## [9] digest_0.6.37 magrittr_2.0.3 evaluate_1.0.1 grid_4.4.2
## [13] bookdown_0.41 pkgload_1.4.0 fastmap_1.2.0 rprojroot_2.0.4
## [17] jsonlite_1.8.9 Matrix_1.7-1 brio_1.1.5 Rmosek_10.2.0
## [21] fansi_1.0.6 scales_1.3.0 codetools_0.2-20 jquerylib_0.1.4
## [25] cli_3.6.3 Rmpfr_0.9-5 rlang_1.1.4 Rglpk_0.6-5.1
## [29] bit64_4.5.2 munsell_0.5.1 withr_3.0.2 cachem_1.1.0
## [33] yaml_2.3.10 tools_4.4.2 osqp_0.6.3.3 Rcplex_0.3-6
## [37] rcbc_0.1.0.9001 dplyr_1.1.4 colorspace_2.1-1 gurobi_11.0-0
## [41] assertthat_0.2.1 vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
## [45] bit_4.5.0 desc_1.4.3 cccp_0.3-1 pkgconfig_2.0.3
## [49] bslib_0.8.0 pillar_1.9.0 gtable_0.3.6 glue_1.8.0
## [53] Rcpp_1.0.13-1 highr_0.11 xfun_0.49 tibble_3.2.1
## [57] tidyselect_1.2.1 knitr_1.48 farver_2.1.2 htmltools_0.5.8.1
## [61] labeling_0.4.3 rmarkdown_2.29 compiler_4.4.2