Committee Potentials

using Plots, ACEpotentials, Statistics

Perform the fit

load some example training data

train, _, _ = ACEpotentials.example_dataset("Si_tiny")
data_keys = (energy_key = "dft_energy", force_key = "dft_force");
 Downloading artifact: Si_tiny_dataset
     Failure artifact: Si_tiny_dataset
 Downloading artifact: Si_tiny_dataset

create model

model = acemodel(elements = [:Si,], order = 3, totaldegree = 8);

Create the solver, setting a nonzero committee size. At present, the SVD factorization is required for committees.

solver = ACEfit.BLR(committee_size=10, factorization=:svd);
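To make the idea of a committee concrete, here is a minimal sketch of how committee members can be drawn from the Gaussian posterior of a Bayesian linear regression using an SVD. This is a generic textbook construction, not ACEfit's implementation: the toy design matrix `A`, the coefficients `c_true`, and the fixed hyperparameters `α` and `β` are all assumptions made for illustration (the actual solver optimizes its hyperparameters by maximizing the marginal likelihood).

```julia
using LinearAlgebra, Random

# Hypothetical toy problem: 40 observations, 5 basis functions.
rng = MersenneTwister(1234)
A = randn(rng, 40, 5)
c_true = [1.0, -2.0, 0.5, 0.0, 3.0]
y = A * c_true + 0.01 * randn(rng, 40)

# Assumed, fixed hyperparameters (illustration only).
α = 1e-2   # prior precision of the coefficients
β = 1e4    # noise precision of the observations

# Thin SVD A = U * Diagonal(S) * V'. In the basis V the posterior precision
# α*I + β*A'A is diagonal with entries d.
F = svd(A)
d = α .+ β .* F.S .^ 2

# Posterior mean of the coefficients.
μ = F.V * ((β .* F.S ./ d) .* (F.U' * y))

# Each committee member is one sample from the posterior N(μ, Σ),
# where Σ = V * Diagonal(1 ./ d) * V'.
committee_size = 10
Z = randn(rng, length(μ), committee_size)
committee = μ .+ F.V * (Z ./ sqrt.(d))   # one column per committee member
```

Each column of `committee` is a full coefficient vector, so evaluating the model with each column in turn gives one prediction per committee member.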

perform the fit

acefit!(model, train; solver=solver, data_keys...);
┌───────────────┬──────────┬───────┬────┬─────┬─────┐
│          Type │ #Configs │ #Envs │ #E │  #F │  #V │
├───────────────┼──────────┼───────┼────┼─────┼─────┤
│ isolated_atom │        1 │     1 │  1 │   3 │   0 │
│           dia │       25 │    50 │ 25 │ 150 │   0 │
│            bt │       25 │    50 │ 25 │ 150 │   0 │
│           liq │        2 │   128 │  2 │ 384 │   0 │
├───────────────┼──────────┼───────┼────┼─────┼─────┤
│         total │       53 │   229 │ 53 │ 687 │   0 │
│       missing │        0 │     0 │  0 │   0 │ 318 │
└───────────────┴──────────┴───────┴────┴─────┴─────┘
[ Info: Assembling linear problem.
[ Info:   - Creating feature matrix with size (740, 54).
[ Info:   - Beginning assembly with processor count:  1.

Progress: 100%|█████████████████████████████████████████| Time: 0:00:15
[ Info:   - Assembly completed.
[ Info: Assembling full weight vector.
[ Info: Entering bayesian_linear_regression_svd
┌ Info: Computing SVD of (740, 54) matrix
  BLAS.get_num_threads() = 2
  BLAS.get_config() =
   LinearAlgebra.BLAS.LBTConfig
   Libraries:
   └ [ILP64] libopenblas64_.so
[ Info: SVD completed after 3.343613333333334e-5 minutes
[ Info: Beginning to maximize marginal likelihood
Iter     Function value   Gradient norm
     0     1.155681e+07     2.291558e+07
 * time: 0.015307903289794922
     1     1.323055e+04     3.211203e-05
 * time: 0.8638138771057129
     2     1.104783e+04     6.894847e-04
 * time: 0.864084005355835
     3     1.085204e+04     9.054319e-04
 * time: 0.8641810417175293
     4     9.406472e+03     7.122083e-03
 * time: 0.8642370700836182
     5     8.339838e+03     3.316252e-02
 * time: 0.8643250465393066
     6     8.060072e+03     4.969368e-02
 * time: 0.8643879890441895
     7     5.440921e+03     1.551345e+00
 * time: 0.864508867263794
     8     5.433645e+03     7.441536e-01
 * time: 0.8646199703216553
     9     5.424063e+03     2.001076e-01
 * time: 0.864670991897583
    10     5.423534e+03     5.746550e-03
 * time: 0.8647139072418213
    11     5.423533e+03     2.605396e-03
 * time: 0.8647639751434326
    12     5.367192e+03     3.472664e+00
 * time: 0.8648710250854492
    13     5.297705e+03     2.489120e+00
 * time: 0.8649449348449707
    14     5.275075e+03     5.514594e-01
 * time: 0.8649890422821045
    15     5.255304e+03     2.075148e-01
 * time: 0.865062952041626
    16     5.160334e+03     1.326801e-01
 * time: 0.8651199340820312
    17     5.141369e+03     1.112467e-01
 * time: 0.8651750087738037
    18     5.100246e+03     2.168017e-02
 * time: 0.8652310371398926
    19     5.049360e+03     3.610391e-02
 * time: 0.8652749061584473
    20     5.006103e+03     6.247583e-02
 * time: 0.8653428554534912
    21     4.995258e+03     5.149655e-02
 * time: 0.865386962890625
    22     4.984123e+03     9.519155e-02
 * time: 0.8654310703277588
    23     4.964259e+03     1.507008e+00
 * time: 0.8654999732971191
    24     4.963605e+03     3.289420e-01
 * time: 0.8655569553375244
    25     4.937312e+03     1.106698e+00
 * time: 0.8656010627746582
    26     4.935445e+03     9.423114e-01
 * time: 0.8656809329986572
    27     4.926953e+03     2.032554e-01
 * time: 0.8657388687133789
    28     4.925570e+03     5.163711e-02
 * time: 0.8657879829406738
    29     4.925349e+03     2.696016e-02
 * time: 0.8658318519592285
    30     4.925294e+03     2.793348e-03
 * time: 0.8658759593963623
    31     4.925294e+03     5.575531e-04
 * time: 0.8659238815307617
    32     4.925294e+03     4.144730e-07
 * time: 0.8659729957580566
    33     4.925294e+03     6.542649e-10
 * time: 0.8660180568695068
    34     4.925294e+03     6.225166e-16
 * time: 0.8660628795623779
    35     4.925294e+03     6.225166e-16
 * time: 0.8661189079284668
┌ Info: Optimization complete
  Results =
    * Status: success

    * Candidate solution
       Final objective value:     4.925294e+03

    * Found with
       Algorithm:     L-BFGS

    * Convergence measures
       |x - x'|               = 7.11e-15 ≤ 1.0e-08
       |x - x'|/|x'|          = 3.96e-17 ≰ 0.0e+00
       |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
       |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
       |g(x)|                 = 6.23e-16 ≰ 0.0e+00

    * Work counters
       Seconds run:   1  (vs limit Inf)
       Iterations:    35
       f(x) calls:    185
       ∇f(x) calls:   185

Inspect the total energies versus the committee energies and error bars for a few perturbed structures. The training set is so small that we do not expect these committees to be particularly useful; the example only illustrates how they might be used. Note also that the energy E is not, in general, the mean of co_E; it is the mean of the exact posterior distribution.

atoms = bulk(:Si, cubic=true) * 2; rattle = [0.03, 0.1, 0.3]
plot(; size = (300, 300), xlabel = "rattle", ylabel = "energy [eV]", ylims = (-10650, -10250),
      xlims = (0.015, 0.6), xticks = (rattle, string.(rattle)), xscale = :log10)
for (i, rt) in enumerate(rattle)
   rattle!(atoms, rt)
   E, co_E = ACE1.co_energy(model.potential, atoms)
   scatter!(rt*ones(10), co_E, c = 1, label=(i==1 ? "committee" : ""))
   scatter!([rt,], [E,], yerror = [std(co_E),], c = 2, ms=6, label=(i==1 ? "mean" : ""))
end
plot!()
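[Plot: committee energies and the mean energy with error bars at each rattle amplitude; x-axis "rattle", y-axis "energy [eV]"]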
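The error bars in the plot are the standard deviation of the committee energies, drawn around E rather than around the committee mean. The sketch below reproduces that arithmetic on made-up numbers; `E` and `co_E` here are hypothetical stand-ins for the values returned by the committee energy call, not results from the fit above.

```julia
using Statistics

# Hypothetical values, for illustration only.
E = -10512.3
co_E = [-10510.9, -10513.8, -10511.7, -10512.9, -10514.2,
        -10511.1, -10512.6, -10513.3, -10510.4, -10512.0]

# Committee spread, used as the error bar on E.
σ_E = std(co_E)

# Offset between the committee mean and E; it is generally nonzero
# because a finite committee only samples the posterior.
offset = mean(co_E) - E

println("error bar = ", σ_E, " eV, committee mean - E = ", offset, " eV")
```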

Committee forces are computed analogously. F is a vector of mean forces (i.e. a vector of 3-vectors), while co_F is a collection of committee forces with one entry per committee member (i.e. a vector of vectors of 3-vectors).

F, co_F = ACE1.co_forces(model.potential, atoms)
@show typeof(F)
@show typeof(co_F);
typeof(F) = Vector{StaticArraysCore.SVector{3, Float64}}
typeof(co_F) = StaticArraysCore.SVector{10, Vector{StaticArraysCore.SVector{3, Float64}}}
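Given this nesting (committee member, then atom, then Cartesian component), a per-atom force uncertainty can be computed as the root-mean-square deviation of the committee forces from F. The following is a minimal sketch on made-up data with three committee members and two atoms; plain vectors are used instead of static vectors so that it runs without extra packages, and the name `σ_F` is an assumption.

```julia
using Statistics, LinearAlgebra

# Made-up mean forces for 2 atoms (illustration only).
F = [[0.10, 0.00, -0.20], [-0.10, 0.00, 0.20]]

# Made-up committee forces: co_F[k][i] is the force on atom i from member k.
co_F = [
    [[0.10,  0.00, -0.20], [-0.10,  0.00, 0.20]],
    [[0.12,  0.01, -0.18], [-0.12, -0.01, 0.18]],
    [[0.08, -0.01, -0.22], [-0.08,  0.01, 0.22]],
]

# RMS deviation of the committee forces from F, one number per atom.
σ_F = [sqrt(mean(norm(co_F[k][i] - F[i])^2 for k in eachindex(co_F)))
       for i in eachindex(F)]

# The atom with the largest spread is where the committee disagrees most.
i_max = argmax(σ_F)
```

A large value of `σ_F[i]` flags atom `i` as sitting in an environment where the committee members disagree, which is the kind of signal one might use to select configurations for further training.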

The situation is analogous for committee virials.

V, co_V = ACE1.co_virial(model.potential, atoms)
@show typeof(V)
@show typeof(co_V);
typeof(V) = StaticArraysCore.SMatrix{3, 3, Float64, 9}
typeof(co_V) = StaticArraysCore.SVector{10, StaticArraysCore.SMatrix{3, 3, Float64, 9}}
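Because each committee virial is a 3×3 matrix, a component-wise spread can be obtained by stacking the committee matrices along a third dimension. This is a sketch on made-up matrices; `V0`, the shifts, and the name `σ_V` are assumptions for illustration.

```julia
using Statistics, LinearAlgebra

# Made-up reference virial and three committee virials that differ
# from it by a shift of the diagonal (illustration only).
V0 = [1.0 0.1 0.0;
      0.1 2.0 0.0;
      0.0 0.0 3.0]
co_V = [V0 + δ * I for δ in (-0.1, 0.0, 0.1)]

# Stack into a 3×3×N array and take the standard deviation over members.
stacked = cat(co_V...; dims = 3)
σ_V = dropdims(std(stacked; dims = 3); dims = 3)   # 3×3 matrix of spreads
```

In this toy case only the diagonal entries of `σ_V` are nonzero, since the committee members differ only on the diagonal.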

This page was generated using Literate.jl.