Sparse Solvers
This short tutorial introduces the Lasso homotopy solver ASP (Active Set Pursuit) and the Orthogonal Matching Pursuit (OMP) solver. Both are sparse solvers that compute an entire solution path. The path shows how the support (the set of active basis functions) evolves as the regularization parameter (for ASP) or the number of selected features (for OMP) changes. For more details on the algorithms and their implementation, see ActiveSetPursuit.jl.
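To make the idea of a regularization path concrete, here is a minimal, self-contained Julia sketch on synthetic data. It is not the homotopy algorithm used by ASP. It approximates the Lasso path on a fixed grid of λ values using warm-started cyclic coordinate descent, and the helpers soft, lasso_cd! and lasso_path are hypothetical names chosen for illustration. The grid starts just above λmax = ‖Aᵀy‖∞, where the all-zero vector is the Lasso solution, and the support grows as λ decreases.

```julia
using LinearAlgebra, Random

# Soft-thresholding operator: the proximal map of t * |z|.
soft(z, t) = sign(z) * max(abs(z) - t, 0.0)

# Minimise ½‖A*x - y‖² + λ‖x‖₁ by cyclic coordinate descent, warm-started at x.
function lasso_cd!(x, A, y, λ; sweeps = 500)
    colnorm2 = vec(sum(abs2, A; dims = 1))   # ‖a_j‖² for every column j
    r = y - A * x                            # residual for the warm start
    for sweep in 1:sweeps, j in eachindex(x)
        aj = view(A, :, j)
        rho = dot(aj, r) + colnorm2[j] * x[j]
        xnew = soft(rho, λ) / colnorm2[j]
        r .-= aj .* (xnew - x[j])            # keep r == y - A*x up to date
        x[j] = xnew
    end
    return x
end

# Approximate the path on a decreasing grid of λ values,
# warm-starting each solve from the previous solution.
function lasso_path(A, y, λs)
    x = zeros(size(A, 2))
    return [copy(lasso_cd!(x, A, y, λ)) for λ in λs]
end

rng = MersenneTwister(0)
A = randn(rng, 40, 15)
xtrue = zeros(15)
xtrue[[2, 7, 11]] = [1.5, -2.0, 0.8]
y = A * xtrue + 0.01 * randn(rng, 40)

# For λ ≥ λmax the all-zero vector solves the Lasso problem;
# start the grid just above λmax and go down three decades.
λmax = norm(A' * y, Inf)
λs = 1.001 * λmax .* 10 .^ range(0, -3; length = 20)
path = lasso_path(A, y, λs)
support_sizes = [count(!iszero, x) for x in path]   # number of active features per λ
```

Plotting support_sizes against λs gives a (coarse) picture of the support evolution that ASP traces exactly.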
We start by importing ACEpotentials (and possibly other required libraries).
using ACEpotentials
using ACEpotentials.Models: fast_evaluator
using Random, SparseArrays, Plots   # shuffle!, nnz, and plotting
Since sparse solvers automatically select the most relevant features, we usually begin with a model that has a large basis. Here, for demonstration purposes, we use a relatively small model.
model = ace1_model(elements = [:Si], order = 3, totaldegree = 12)
P = algebraic_smoothness_prior(model; p = 4)
223×223 LinearAlgebra.Diagonal{Float64, Vector{Float64}} with diagonal entries 1.0, 16.0, 81.0, 256.0, 625.0, 1296.0, …, 6561.0, 10000.0, 14641.0, 20736.0
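One way to see what a diagonal prior like P does in a sparse fit is through a change of variables. This is a sketch under an assumption we have not checked against ACEfit's source: that the solver works with rescaled coefficients z = P*x. Under that convention, minimizing ‖A*x − y‖² + λ‖P*x‖₁ is equivalent to minimizing ‖(A*P⁻¹)*z − y‖² + λ‖z‖₁. Columns belonging to rougher basis functions (large diagonal entries of P) are shrunk, so a sparse solver is less inclined to select them. The 3×3 matrices below are hypothetical toy data, not the ACE design matrix.

```julia
using LinearAlgebra

A = [1.0 2.0 3.0;
     4.0 5.0 6.0;
     7.0 8.0 10.0]
P = Diagonal([1.0, 16.0, 81.0])   # same pattern as the first entries of the prior above

AP = A / P                        # right division: column j of A divided by P[j, j]
x = [0.5, -1.0, 2.0]
z = P * x                         # coefficients in the rescaled variables

fit_unchanged = A * x ≈ AP * z    # the data fit is invariant under the change of variables
colnorms = [norm(AP[:, j]) for j in 1:3]   # columns with large prior weights shrink
```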
Next, we load a dataset. The Zuo20_Si example dataset comes with separate training and test sets, and we further split a subset of the training data into training and validation sets. The training set is used to compute the solution path, the validation set to select the best solution along the path, and the test set to evaluate the final model.
_train_data, test_data, _ = ACEpotentials.example_dataset("Zuo20_Si")
shuffle!(_train_data);
_train_data = _train_data[1:100] # Limit the dataset size for this tutorial
isplit = floor(Int, 0.8 * length(_train_data))
train_data = _train_data[1:isplit]
val_data = _train_data[isplit+1:end]
20-element Vector{ExtXYZ.Atoms{…}} (the validation set): 16 Si₆₄, three Si₆₃ and one Si₁₂ structure, all fully periodic (periodicity = TTT).
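Because shuffle! is called without a seed, the split (and hence every number further down this page) changes from run to run. A minimal sketch of a reproducible 80/20 split on a stand-in vector of integers follows. The helper split_train_val and the seed value are hypothetical, not part of ACEpotentials, but the same pattern applies to a vector of configurations.

```julia
using Random

# Shuffle a copy of `data` with a fixed seed, then split it into
# training and validation parts (80/20 by default, as in the tutorial).
function split_train_val(data, frac = 0.8; seed = 1234)
    shuffled = shuffle(MersenneTwister(seed), collect(data))
    isplit = floor(Int, frac * length(shuffled))
    return shuffled[1:isplit], shuffled[isplit+1:end]
end

train, val = split_train_val(1:100)
```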
We can now assemble the linear system for the training and validation sets.
At, yt, Wt = ACEpotentials.assemble(train_data, model);
Av, yv, Wv = ACEpotentials.assemble(val_data, model);
[ Info: Assembling linear problem.
[ Info: - Creating feature matrix with size (14375, 223).
[ Info: - Beginning assembly with processor count: 1.
Progress: 100%|█████████████████████████████████████████| Time: 0:00:52
[ Info: - Assembly completed.
[ Info: Assembling full weight vector.
Progress: 100%|█████████████████████████████████████████| Time: 0:00:00
[ Info: Assembling linear problem.
[ Info: - Creating feature matrix with size (3695, 223).
[ Info: - Beginning assembly with processor count: 1.
Progress: 100%|█████████████████████████████████████████| Time: 0:00:09
[ Info: - Assembly completed.
[ Info: Assembling full weight vector.
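The solver calls further down pass Wt .* At and Wt .* yt, i.e. the weights enter by scaling rows. A small sketch with hypothetical toy data: broadcasting a weight vector over the rows is the same as multiplying by Diagonal(W). Ordinary least squares on the scaled system therefore minimizes Σᵢ (Wᵢ (Aᵢ·x − yᵢ))².

```julia
using LinearAlgebra

A = [1.0 0.0; 1.0 1.0; 1.0 2.0; 1.0 3.0]
y = [1.0, 2.0, 2.0, 5.0]
W = [1.0, 1.0, 10.0, 1.0]        # hypothetical weights: row 3 matters most

WA = W .* A                      # scales row i of A by W[i]
Wy = W .* y
x = WA \ Wy                      # ordinary least squares on the scaled system

# Objective minimised by x: Σᵢ (W[i] * (A[i, :]⋅x - y[i]))²
weighted_sse(x) = sum(abs2, W .* (A * x .- y))
```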
We can now compute sparse solution paths using the ASP and OMP solvers. These solvers support customizable selection criteria for choosing a solution along the path.
The select keyword controls which solution is returned:
- :final selects the final iterate on the path.
- (:bysize, n) selects the solution with exactly n active parameters.
- (:byerror, ε) selects the smallest solution (fewest active parameters) whose validation error is within a factor ε of the minimum validation error.
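The three selection rules can be sketched on a toy path. Each path entry here is a NamedTuple with a sparse solution and a validation rmse, mirroring the solution and rmse fields read in the plotting code at the end of this page. select_solution is a hypothetical helper, not ACEfit.asp_select, and details such as tie-breaking or the handling of a missing size may differ in ACEfit.

```julia
using SparseArrays

# Hypothetical helper illustrating the `select` rules on a solution path.
function select_solution(path, rule)
    rule === :final && return path[end]
    kind, val = rule
    if kind === :bysize
        i = findfirst(p -> nnz(p.solution) == val, path)
        i === nothing && error("no solution with exactly $val active parameters")
        return path[i]
    elseif kind === :byerror
        tol = val * minimum(p.rmse for p in path)
        ok = filter(p -> p.rmse <= tol, path)
        return ok[argmin([nnz(p.solution) for p in ok])]   # sparsest acceptable solution
    end
    error("unknown rule $rule")
end

# Toy path with 1, 2, 3 and 4 active parameters out of 4.
path = [(solution = sparsevec([1], [0.9], 4), rmse = 0.50),
        (solution = sparsevec([1, 3], [1.0, 0.4], 4), rmse = 0.12),
        (solution = sparsevec([1, 2, 3], [1.0, 0.1, 0.5], 4), rmse = 0.10),
        (solution = sparsevec([1, 2, 3, 4], [1.0, 0.1, 0.5, 0.02], 4), rmse = 0.11)]
```

With (:byerror, 1.3), the tolerance is 0.13, so the second entry (two active parameters) is returned, even though the third entry has the lowest validation error.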
The tsvd keyword controls whether the solution is post-processed using a truncated SVD. This is often beneficial for ASP, as ℓ1-regularization can shrink coefficients toward zero too aggressively.
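The motivation for tsvd can be illustrated as follows. Once an ℓ1 solver has identified the active set, re-solving an ordinary least-squares problem restricted to that set removes the shrinkage bias, and a truncated SVD keeps that re-solve stable when the active columns are nearly collinear. This is a minimal sketch of the idea under our own assumptions (a relative truncation threshold rtol and a hypothetical helper tsvd_refit). It is not ACEfit's implementation. The toy design has orthonormal columns, for which the Lasso solution of ½‖A*x − y‖² + λ‖x‖₁ is exactly the soft-thresholded Aᵀy.

```julia
using LinearAlgebra

# Re-solve min ‖A[:, S] * c - y‖ on the support S of `x` with a truncated SVD:
# singular values below rtol * σ_max are discarded.
function tsvd_refit(A, y, x; rtol = 1e-8)
    S = findall(!iszero, x)
    F = svd(A[:, S])
    keep = F.S .> rtol * maximum(F.S)
    c = F.V[:, keep] * ((F.U[:, keep]' * y) ./ F.S[keep])
    xr = zeros(length(x))
    xr[S] = c
    return xr
end

soft(z, t) = sign(z) * max(abs(z) - t, 0.0)

A = [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 0.0 0.0 0.0]   # orthonormal columns
xtrue = [2.0, 0.0, -1.0]
y = A * xtrue
λ = 0.5
xlasso = soft.(A' * y, λ)            # exact Lasso solution here: [1.5, 0.0, -0.5]
xrefit = tsvd_refit(A, y, xlasso)    # shrinkage removed: ≈ [2.0, 0.0, -1.0]
```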
The actMax keyword sets the maximum number of active parameters in the solution.
solver_asp = ACEfit.ASP(; P = P, select = :final, tsvd = true, actMax = 100, loglevel = 0);
asp_result = ACEfit.solve(solver_asp, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv);
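We can also compute the OMP path. OMP is a greedy algorithm: at each iteration it adds the feature most correlated with the current residual to the active set and re-fits the coefficients by least squares on that set.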
solver_omp = ACEfit.OMP(; P = P, select = :final, tsvd = false, actMax = 100, loglevel = 0);
omp_result = ACEfit.solve(solver_omp, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv);
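A textbook version of OMP fits in a few lines. The helper omp_path below is hypothetical. ACEfit.OMP also takes the prior P, the actMax limit and validation data, and may differ in details. The sketch records one solution per iteration, so the support grows by exactly one feature per step. On the toy problem the design has orthonormal columns, a setting in which OMP is guaranteed to recover the true support.

```julia
using LinearAlgebra

# Textbook Orthogonal Matching Pursuit: at each step add the column most
# correlated with the current residual, then re-fit by least squares on the
# active set. Returns one coefficient vector per step.
function omp_path(A, y, kmax)
    n = size(A, 2)
    support = Int[]
    r = copy(y)
    path = Vector{Vector{Float64}}()
    for step in 1:kmax
        c = abs.(A' * r)
        c[support] .= 0               # never re-select an active column
        push!(support, argmax(c))
        xs = A[:, support] \ y        # least-squares fit on the active set
        r = y - A[:, support] * xs    # new residual, orthogonal to the active columns
        x = zeros(n)
        x[support] = xs
        push!(path, x)
    end
    return path
end

# Toy problem with orthonormal columns (first 10 columns of a QR factor).
A = Matrix(qr(randn(30, 10)).Q)[:, 1:10]
xtrue = zeros(10)
xtrue[[2, 5, 9]] = [3.0, -2.0, 1.0]
y = A * xtrue
path = omp_path(A, y, 3)
```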
To demonstrate the use of the sparse solvers, we generate models with different numbers of active parameters: the final model on the path, a model with 50 active parameters, and the sparsest model whose validation error is within 1.3 times the minimum validation error. The ACEfit.asp_select function extracts the desired solutions from the result.
asp_final = set_parameters!( deepcopy(model),
ACEfit.asp_select(asp_result, :final)[1]);
asp_size_50 = set_parameters!( deepcopy(model),
ACEfit.asp_select(asp_result, (:bysize, 50))[1]);
asp_error13 = set_parameters!( deepcopy(model),
ACEfit.asp_select(asp_result, (:byerror, 1.3))[1]);
pot_final = fast_evaluator(asp_final; aa_static = false);
pot_50 = fast_evaluator(asp_size_50; aa_static = true);
pot_13 = fast_evaluator(asp_error13; aa_static = true);
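err_13 = ACEpotentials.compute_errors(test_data, pot_13);     # first RMSE/MAE tables: (:byerror, 1.3) model
err_50 = ACEpotentials.compute_errors(test_data, pot_50);     # second pair: 50 active parameters
err_fin = ACEpotentials.compute_errors(test_data, pot_final); # third pair: final model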
[ Info: RMSE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 7.408 │ 0.181 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 7.408 │ 0.181 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: MAE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 5.341 │ 0.122 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 5.341 │ 0.122 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: RMSE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 10.486 │ 0.205 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 10.486 │ 0.205 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: MAE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 6.057 │ 0.133 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 6.057 │ 0.133 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: RMSE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 4.140 │ 0.160 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 4.140 │ 0.160 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: MAE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 3.087 │ 0.107 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 3.087 │ 0.107 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
Similarly, we can compute the errors for the OMP models.
omp_final = set_parameters!( deepcopy(model),
ACEfit.asp_select(omp_result, :final)[1]);
omp_50 = set_parameters!( deepcopy(model),
ACEfit.asp_select(omp_result, (:bysize, 50))[1]);
omp_13 = set_parameters!( deepcopy(model),
ACEfit.asp_select(omp_result, (:byerror, 1.3))[1]);
pot_fin = fast_evaluator(omp_final; aa_static = false);
pot_50 = fast_evaluator(omp_50; aa_static = true);
pot_13 = fast_evaluator(omp_13; aa_static = true);
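err_13 = ACEpotentials.compute_errors(test_data, pot_13);   # first RMSE/MAE tables: (:byerror, 1.3) model
err_50 = ACEpotentials.compute_errors(test_data, pot_50);   # second pair: 50 active parameters
err_fin = ACEpotentials.compute_errors(test_data, pot_fin); # third pair: final model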
[ Info: RMSE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 11.508 │ 0.222 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 11.508 │ 0.222 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: MAE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 6.279 │ 0.143 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 6.279 │ 0.143 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: RMSE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 10.174 │ 0.186 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 10.174 │ 0.186 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: MAE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 7.007 │ 0.121 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 7.007 │ 0.121 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: RMSE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 3.670 │ 0.161 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 3.670 │ 0.161 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
[ Info: MAE Table
┌──────┬─────────┬──────────┬─────────┐
│ Type │ E [meV] │ F [eV/A] │ V [meV] │
├──────┼─────────┼──────────┼─────────┤
│ nil │ 2.844 │ 0.107 │ 0.000 │
├──────┼─────────┼──────────┼─────────┤
│ set │ 2.844 │ 0.107 │ 0.000 │
└──────┴─────────┴──────────┴─────────┘
Finally, we can visualize the results along the solution path. We plot the validation error as a function of the number of active parameters for both ASP and OMP.
path_asp = asp_result["path"];
path_omp = omp_result["path"];
nz_counts_asp = [nnz(p.solution) for p in path_asp];
nz_counts_omp = [nnz(p.solution) for p in path_omp];
rmses_asp = [p.rmse for p in path_asp];
rmses_omp = [p.rmse for p in path_omp];
plot(nz_counts_asp, rmses_asp;
xlabel = "# Nonzero Coefficients",
ylabel = "RMSE",
title = "RMSE vs Sparsity Level",
marker = :o,
grid = true, yscale = :log10, label = "ASP")
plot!(nz_counts_omp, rmses_omp; marker = :o, label = "OMP")
This page was generated using Literate.jl.