ETACE Models Tutorial

This tutorial demonstrates how to use the EquivariantTensors (ET) backend for ACE models in ACEpotentials.jl. The ET backend provides:

  • Graph-based evaluation (edge-centric computation)
  • Automatic differentiation via Zygote
  • GPU-ready architecture via KernelAbstractions
  • Lux.jl layer integration

We cover two approaches:

  1. Converting from an existing ACE model (recommended)
  2. Creating an ETACE model from scratch (for advanced users)

# Load required packages
using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful
using AtomsCalculators, Random, LinearAlgebra

M = ACEpotentials.Models
ETM = ACEpotentials.ETModels
import EquivariantTensors as ET
import Polynomials4ML as P4ML

rng = Random.MersenneTwister(1234)
Random.MersenneTwister(1234)

Part 1: Converting from an Existing ACE Model (Recommended)

The simplest way to obtain an ETACE model is to build a standard ACE model with the familiar construction API (M.ace_model) and then convert it.

# Define model hyperparameters
elements = (:Si, :O)
order = 3          # correlation order (body-order = order + 1)
max_level = 10     # total polynomial degree
maxl = 6           # maximum angular momentum
rcut = 5.5         # cutoff radius in Angstrom

# Create the standard ACE model
rin0cuts = M._default_rin0cuts(elements)
# keep the default rin and r0 for every species pair, but set the outer cutoff to rcut
rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts)

# Note: pair_learnable = true is required for ET conversion; the default
# spline-based pair potential is not yet supported by convert2et
model = M.ace_model(;
   elements = elements,
   order = order,
   Ytype = :solid,
   level = M.TotalDegree(),
   max_level = max_level,
   maxl = maxl,
   pair_maxn = max_level,
   rin0cuts = rin0cuts,
   E0s = Dict(:Si => -0.846, :O => -1.023),  # reference energies
   pair_learnable = true   # required for ET conversion
)

# Initialize parameters with Lux
ps, st = Lux.setup(rng, model)

@info "Standard ACE model created"
@info "  Number of basis functions: $(M.length_basis(model))"
[ Info: Standard ACE model created
[ Info:   Number of basis functions: 240
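The broadcast in the second rin0cuts line rebuilds every entry with the new outer cutoff while keeping rin and r0. A minimal stand-alone sketch of that idiom, assuming the defaults are stored as a matrix of NamedTuples indexed by species pair (the container type and all numbers below are made up for illustration):

```julia
# Hypothetical default cutoffs for a 2-species system: (rin, r0, rcut) per species pair
nt(r0) = (rin = 0.0, r0 = r0, rcut = 6.0)
defaults = reshape([nt(2.3), nt(1.9), nt(1.9), nt(1.5)], 2, 2)

rcut = 5.5
# Same idiom as in the tutorial: rebuild each NamedTuple with the new rcut
updated = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(defaults)

# Equivalent alternative: merge replaces only rcut and keeps any other fields
updated_alt = map(x -> merge(x, (rcut = rcut,)), defaults)
```

Broadcasting returns a new container of the same shape, so the original defaults are left untouched.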

Method A: Convert full model (E0 + Pair + Many-body) to StackedCalculator

# convert2et_full creates a StackedCalculator that sums three components,
# each model wrapped in its own calculator:
#   - ETOneBody    (reference energies per species)
#   - ETPairModel  (pair potential)
#   - ETACE        (many-body ACE potential)

et_calc_full = ETM.convert2et_full(model, ps, st; rng=rng)

@info "Full conversion to StackedCalculator"
@info "  Contains: ETOneBody + ETPairPotential + ETACEPotential"
@info "  Total linear parameters: $(ETM.length_basis(et_calc_full))"
[ Info: Full conversion to StackedCalculator
[ Info:   Contains: ETOneBody + ETPairPotential + ETACEPotential
[ Info:   Total linear parameters: 240

Method B: Convert only the many-body ACE component

# convert2et creates just the ETACE model (many-body only, no E0 or pair)
et_ace = ETM.convert2et(model)
et_ace_ps, et_ace_st = Lux.setup(rng, et_ace)

# Copy parameters from the original model
ETM.copy_ace_params!(et_ace_ps, ps, model)

# Wrap in calculator for AtomsCalculators interface
et_ace_calc = ETM.ETACEPotential(et_ace, et_ace_ps, et_ace_st, rcut)

@info "Many-body only conversion"
@info "  ETACE basis size: $(ETM.length_basis(et_ace_calc))"
[ Info: Many-body only conversion
[ Info:   ETACE basis size: 220

Method C: Convert only the pair potential

# convertpair creates an ETPairModel
et_pair = ETM.convertpair(model)
et_pair_ps, et_pair_st = Lux.setup(rng, et_pair)

# Copy parameters from the original model
ETM.copy_pair_params!(et_pair_ps, ps, model)

# Wrap in calculator
et_pair_calc = ETM.ETPairPotential(et_pair, et_pair_ps, et_pair_st, rcut)

@info "Pair potential only conversion"
@info "  ETPairModel basis size: $(ETM.length_basis(et_pair_calc))"
[ Info: Pair potential only conversion
[ Info:   ETPairModel basis size: 20

Part 2: Using ETACE Calculators

# Create a test system
sys = AtomsBuilder.bulk(:Si) * (2, 2, 1)
rattle!(sys, 0.1u"Å")
AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5])

@info "Test system: $(length(sys)) atoms"

# Evaluate energy, forces, virial using AtomsCalculators interface
E = AtomsCalculators.potential_energy(sys, et_calc_full)
F = AtomsCalculators.forces(sys, et_calc_full)
V = AtomsCalculators.virial(sys, et_calc_full)

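# The model has not been fitted yet: the energy printed below equals the
# one-body (E0) energy from Part 5 and the forces are zero, so the pair and
# many-body terms do not contribute yet
@info "Energy evaluation with full ETACE calculator"
@info "  Energy: $E"
@info "  Max force magnitude: $(maximum(norm.(F)))"

# Combined evaluation (typically cheaper than three separate calls)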
efv = AtomsCalculators.energy_forces_virial(sys, et_calc_full)
@info "  Combined EFV evaluation successful"
[ Info: Test system: 8 atoms
[ Info: Energy evaluation with full ETACE calculator
[ Info:   Energy: -7.652999999999999 eV
[ Info:   Max force magnitude: 0.0 eV Å^-1
[ Info:   Combined EFV evaluation successful
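The calculator interface returns the energy together with its derivatives: the forces are F_i = -∂E/∂r_i, and the virial is the negative strain derivative of the energy (the convention assumed here). A self-contained check of these relations for a hypothetical toy pair potential (plain Julia, not an ACE model), comparing analytic forces and virial against central finite differences:

```julia
using LinearAlgebra

# Hypothetical toy pair potential with a smooth cutoff (not an ACE model)
const RCUT = 5.5
φ(r) = r < RCUT ? (1 - r / RCUT)^4 : 0.0
dφ(r) = r < RCUT ? -4 / RCUT * (1 - r / RCUT)^3 : 0.0

# E = sum over pairs i < j of φ(|r_i - r_j|)
energy(X) = sum(φ(norm(X[i] - X[j])) for i in eachindex(X) for j in i+1:lastindex(X); init = 0.0)

# Analytic forces: F_i = -∂E/∂r_i
function forces(X)
    F = [zeros(3) for _ in X]
    for i in eachindex(X), j in eachindex(X)
        i == j && continue
        rij = X[i] - X[j]
        r = norm(rij)
        F[i] .-= dφ(r) / r .* rij
    end
    return F
end

# Analytic virial: W = -∂E/∂ε = -Σ_{i<j} φ'(r) r_ij r_ijᵀ / r
function virial(X)
    W = zeros(3, 3)
    for i in eachindex(X), j in i+1:lastindex(X)
        rij = X[i] - X[j]
        r = norm(rij)
        W .-= dφ(r) / r .* (rij * rij')
    end
    return W
end

# Central finite differences of the energy with respect to positions
function fd_forces(X; h = 1e-6)
    F = [zeros(3) for _ in X]
    for i in eachindex(X), a in 1:3
        Xp = deepcopy(X); Xp[i][a] += h
        Xm = deepcopy(X); Xm[i][a] -= h
        F[i][a] = -(energy(Xp) - energy(Xm)) / (2h)
    end
    return F
end

# Central finite differences with respect to a homogeneous strain x -> (I + ε) x
function fd_virial(X; h = 1e-6)
    W = zeros(3, 3)
    for a in 1:3, b in 1:3
        ε = zeros(3, 3); ε[a, b] = h
        Ep = energy([(I + ε) * x for x in X])
        Em = energy([(I - ε) * x for x in X])
        W[a, b] = -(Ep - Em) / (2h)
    end
    return W
end

X = [[0.0, 0.0, 0.0], [2.0, 0.1, 0.0], [0.3, 2.1, 0.2], [1.1, 1.0, 2.0]]
F, W = forces(X), virial(X)
```

A finite-difference comparison like this is a cheap sanity check that only needs the energy, so it can be applied to any calculator.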

Part 3: Training Assembly (for Linear Fitting)

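The ETACE calculators provide training-assembly functions for integration with ACEfit. Since these models are linear in their linear parameters, the functions return basis values, which form the rows of the design matrix for linear least-squares fitting.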

# Energy-only basis evaluation (fastest)
E_basis = ETM.potential_energy_basis(sys, et_ace_calc)
@info "Energy basis: $(length(E_basis)) components"

# Full energy, forces, virial basis
efv_basis = ETM.energy_forces_virial_basis(sys, et_ace_calc)
@info "EFV basis shapes:"
@info "  Energy:  $(size(efv_basis.energy))"
@info "  Forces:  $(size(efv_basis.forces))"
@info "  Virial:  $(size(efv_basis.virial))"

# Get/set linear parameters
params = ETM.get_linear_parameters(et_ace_calc)
@info "Linear parameters: $(length(params)) values"

# Parameters can be updated for fitting:
# ETM.set_linear_parameters!(et_ace_calc, new_params)
[ Info: Energy basis: 220 components
[ Info: EFV basis shapes:
[ Info:   Energy:  (220,)
[ Info:   Forces:  (8, 220)
[ Info:   Virial:  (220,)
[ Info: Linear parameters: 220 values
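Because the energy is linear in these parameters, E = Σ_k θ_k B_k, the basis rows can be stacked (with weights) into a design matrix A, and the parameters recovered by solving the least-squares problem A θ ≈ y. The sketch below uses random numbers in place of energy_forces_virial_basis output and flattens each atom's force basis into 3 rows (how the real forces array stores its 3 components is an assumption). It solves the problem with Julia's QR-based backslash; it is not ACEfit's assembly code:

```julia
using LinearAlgebra, Random

# Build a weighted design matrix from synthetic (random) basis evaluations.
# Each structure contributes 1 energy row and 3 * natoms force rows.
function assemble(rng, θ_true; nstructs = 20, natoms = 8, wE = 30.0, wF = 1.0)
    nbasis = length(θ_true)
    nrows = nstructs * (1 + 3natoms)
    A = zeros(nrows, nbasis)
    y = zeros(nrows)
    row = 0
    for s in 1:nstructs
        E_b = randn(rng, nbasis)           # stand-in for the energy basis of one structure
        F_b = randn(rng, 3natoms, nbasis)  # stand-in for the force basis, 3 rows per atom
        row += 1
        A[row, :] .= wE .* E_b
        y[row] = wE * dot(E_b, θ_true)     # noise-free synthetic energy
        for a in 1:3natoms
            row += 1
            A[row, :] .= wF .* F_b[a, :]
            y[row] = wF * dot(F_b[a, :], θ_true)
        end
    end
    return A, y
end

rng = MersenneTwister(0)
θ_true = randn(rng, 12)          # "unknown" linear parameters (12 basis functions, arbitrary)
A, y = assemble(rng, θ_true)
θ_fit = A \ y                    # least-squares solution via QR
```

With real data, the fitted vector would then be written back with ETM.set_linear_parameters!(et_ace_calc, θ_fit). The weights wE and wF here are arbitrary placeholders.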

Part 4: Creating an ETACE Model from Scratch (Advanced)

This part is for advanced users who want direct control over the model architecture; it requires familiarity with the EquivariantTensors.jl API.

# Define model parameters
scratch_elements = [:Si, :O]
scratch_maxn = 6      # number of radial basis functions
scratch_maxl = 4      # maximum angular momentum
scratch_order = 2     # correlation order (not used below: mb_spec contains only 1-correlations)
scratch_rcut = 5.5    # cutoff radius

# Species information
zlist = ChemicalSpecies.(scratch_elements)
NZ = length(zlist)
2
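With NZ = 2 species, the radial layer below keeps NZ^2 = 4 weight blocks, one per ordered (center, neighbor) species pair, and the readout keeps NZ = 2 blocks, one per center species. The selectors map species to these block indices. Below is a hypothetical stand-in for ET.cat2idx and ET.catcat2idx that uses Symbols instead of ChemicalSpecies (the pair ordering used by EquivariantTensors may differ):

```julia
species = [:Si, :O]
NZ = length(species)

# Index of a species in 1:NZ (hypothetical analogue of ET.cat2idx)
cat2idx(z) = findfirst(==(z), species)

# Index of an ordered species pair in 1:NZ^2 (hypothetical analogue of ET.catcat2idx)
pair2idx(z0, z1) = (cat2idx(z0) - 1) * NZ + cat2idx(z1)

pairs_idx = [pair2idx(a, b) for a in species for b in species]
```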

Build the radial embedding (Rnl)

# Radial specification (n, l pairs)
Rnl_spec = [(n=n, l=l) for n in 1:scratch_maxn for l in 0:scratch_maxl]

# Distance transform: r -> transformed coordinate y
# For simplicity, use a linear map of r ∈ [0, rcut] onto y ∈ [1, -1]
# (this is not an Agnesi transform)
f_trans = let rcut = scratch_rcut
   (x, st) -> begin
      r = norm(x.𝐫)
      y = 1 - 2 * r / rcut
      return y
   end
end
trans = ET.NTtransformST(f_trans, NamedTuple())

# Envelope function: smooth cutoff; (1 - y^2)^2 vanishes at y = ±1,
# i.e. at r = rcut (and, with this transform, also at r = 0)
f_env = y -> (1 - y^2)^2  # quartic envelope

# Polynomial basis (Chebyshev)
polys = P4ML.ChebBasis(scratch_maxn)
# Penv(y) = f_env(y) .* P(y): evaluate the polynomials and the envelope in parallel, then multiply
Penv = P4ML.wrapped_basis(Lux.BranchLayer(
   polys,
   Lux.WrappedFunction(y -> f_env.(y)),
   fusion = Lux.WrappedFunction(Pe -> Pe[2] .* Pe[1])
))

# Species-pair selector for radial weights
selector_ij = let zlist = tuple(zlist...)
   xij -> ET.catcat2idx(zlist, xij.z0, xij.z1)
end

# Linear layer: P(yij) -> W[(Zi, Zj)] * P(yij)
linl = ET.SelectLinL(scratch_maxn, length(Rnl_spec), NZ^2, selector_ij)

# Complete radial embedding
rbasis = ET.EmbedDP(trans, Penv, linl)
rembed = ET.EdgeEmbed(rbasis)
EdgeEmbed(
    layer = EmbedDP(
        trans = NTtransformST(),
        basis = Polynomials4ML.WrappedBasis{Lux.BranchLayer{@NamedTuple{layer_1::Polynomials4ML.ChebBasis{6}, layer_2::Lux.WrappedFunction{Main.var"#18#19"}}, Lux.WrappedFunction{Main.var"#20#21"}, Nothing}}(Lux.BranchLayer{@NamedTuple{layer_1::Polynomials4ML.ChebBasis{6}, layer_2::Lux.WrappedFunction{Main.var"#18#19"}}, Lux.WrappedFunction{Main.var"#20#21"}, Nothing}((layer_1 = ChebBasis(6), layer_2 = WrappedFunction(#18)), WrappedFunction(#20), nothing), 6),
        post = EquivariantTensors.SelectLinL{Main.var"#23#24"{Tuple{AtomsBase.ChemicalSpecies, AtomsBase.ChemicalSpecies}}}(6, 30, 4, Main.var"#23#24"{Tuple{AtomsBase.ChemicalSpecies, AtomsBase.ChemicalSpecies}}((Si, O))),  # 720 parameters
    ),
)         # Total: 720 parameters,
          #        plus 0 states.
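Up to the learned linear layer, each radial function is a linear combination of envelope-weighted Chebyshev polynomials of y = 1 - 2r/rcut. The plain-Julia sketch below evaluates that envelope-weighted basis, storing T_0, …, T_{N-1} in positions 1..N via T_0 = 1, T_1 = y, T_{k+1} = 2y T_k - T_{k-1}. P4ML.ChebBasis may differ in normalization or index offset. The sketch also shows a side effect of this simple transform: since y = 1 at r = 0, every basis function vanishes there as well as at r = rcut.

```julia
rcut, maxn = 5.5, 6

y_of_r(r) = 1 - 2r / rcut          # linear map [0, rcut] -> [1, -1]
f_env(y) = (1 - y^2)^2             # quartic envelope, zero at y = ±1

# Chebyshev polynomials T_0 … T_{N-1}, stored 1-based
function cheb(y, N)
    T = zeros(N)
    T[1] = 1.0
    N >= 2 && (T[2] = y)
    for k in 3:N
        T[k] = 2y * T[k-1] - T[k-2]
    end
    return T
end

# Envelope-weighted basis: the input to the species-pair linear layer
Penv(r) = f_env(y_of_r(r)) .* cheb(y_of_r(r), maxn)

# Radial functions for one species pair: a (30 × 6) weight block (random here)
W = randn(30, maxn)
R(r) = W * Penv(r)
```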

Build the angular embedding (Ylm)

# Spherical harmonics basis
ylm_basis = P4ML.real_sphericalharmonics(scratch_maxl)
Ylm_spec = P4ML.natural_indices(ylm_basis)

# Angular embedding: edge direction -> spherical harmonics
ybasis = ET.EmbedDP(
   ET.NTtransformST((x, st) -> x.𝐫, NamedTuple()),
   ylm_basis
)
yembed = ET.EdgeEmbed(ybasis)
EdgeEmbed(
    layer = EmbedDP(
        trans = NTtransformST(),
        basis = SphericalHarmonics(ℝ, maxl=4),
        post = EquivariantTensors.IDpost(),
    ),
)         # Total: 0 parameters,
          #        plus 25 states.
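The 25 states of the angular embedding correspond to the (maxl + 1)^2 = 25 real spherical harmonics Y_lm with 0 ≤ l ≤ 4 and -l ≤ m ≤ l. The sketch enumerates that index set (P4ML.natural_indices may order it differently). It then checks the addition theorem Σ_m Y_lm(r̂)^2 = (2l + 1)/(4π) for l = 1, a simple example of how summing products over m removes the dependence on direction, the same idea the invariant (L = 0) many-body basis relies on. The l = 1 formulas assume the common real-harmonic convention (m = -1, 0, 1 ↔ y, z, x).

```julia
using LinearAlgebra

maxl = 4
lm_spec = [(l = l, m = m) for l in 0:maxl for m in -l:l]

# Real l = 1 harmonics, ordered m = -1, 0, 1 (assumed convention)
Y1(r) = sqrt(3 / (4pi)) .* [r[2], r[3], r[1]] ./ norm(r)

# Addition theorem: Σ_m Y_1m(r̂)^2 = 3/(4π) for every direction
s1 = sum(abs2, Y1([0.3, -1.2, 0.7]))
s2 = sum(abs2, Y1([-2.0, 0.5, 0.1]))
```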

Build the many-body basis (sparse ACE)

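# Define the many-body specification: each entry lists the (n, l) factors
# of one correlation. For simplicity, use only 1-correlations (one factor each).
# For a scalar (L = 0) output only the l = 0 entries yield invariants, so the
# resulting basis has scratch_maxn = 6 functions (see the output below).
mb_spec = [[(n=n, l=l)] for n in 1:scratch_maxn for l in 0:scratch_maxl]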

# Create sparse equivariant tensor (ACE basis)
mb_basis = ET.sparse_equivariant_tensor(
   L = 0,                # scalar (invariant) output
   mb_spec = mb_spec,
   Rnl_spec = Rnl_spec,
   Ylm_spec = Ylm_spec,
   basis = real          # real-valued basis
)
SparseACEbasis(L = (0,))
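The many-body basis size of 6 follows from this spec. mb_spec lists all 30 single (n, l) factors, but a single factor A_nlm is invariant under rotations only for l = 0, leaving one invariant per n. Raising the correlation order to 2 (scratch_order) would add products of two factors, which can be coupled to L = 0 only when l1 = l2. Here is a counting sketch in plain Julia; the 2-correlation count ignores any degree truncation and only shows how quickly the basis grows:

```julia
maxn, maxl = 6, 4
rnl = [(n = n, l = l) for n in 1:maxn for l in 0:maxl]   # same order as Rnl_spec

# 1-correlations: a single factor A_nlm is rotation-invariant only for l = 0
mb1 = [[b] for b in rnl]
inv1 = filter(c -> c[1].l == 0, mb1)

# 2-correlations (unordered pairs): coupling to L = 0 requires l1 == l2
inv2 = [(b1, b2) for (i, b1) in enumerate(rnl) for b2 in rnl[i:end] if b1.l == b2.l]
```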

Build the readout layer

# Species selector for readout
selector_i = let zlist = zlist
   x -> ET.cat2idx(zlist, x.z)
end

# Readout: basis values -> site energies
readout = ET.SelectLinL(
   mb_basis.lens[1],     # input dimension (basis length)
   1,                    # output dimension (site energy)
   NZ,                   # number of species categories
   selector_i
)
EquivariantTensors.SelectLinL{Main.var"#35#36"{Vector{AtomsBase.ChemicalSpecies}}}(6, 1, 2, Main.var"#35#36"{Vector{AtomsBase.ChemicalSpecies}}(AtomsBase.ChemicalSpecies[Si, O]))  # 12 parameters
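The parameter counts printed above, 720 = 30 × 6 × 4 for the radial layer and 12 = 1 × 6 × 2 for the readout, match a layer that stores one dense matrix per category and picks it with the selector. A hypothetical re-implementation of that idea (the internal storage of ET.SelectLinL may differ):

```julia
# Hypothetical stand-in for ET.SelectLinL: one dense (out × in) matrix per category
struct SelectLinear{F}
    W::Array{Float64,3}   # size (nout, nin, ncat)
    select::F             # input state -> category index in 1:ncat
end
SelectLinear(nin::Int, nout::Int, ncat::Int, select) =
    SelectLinear(zeros(nout, nin, ncat), select)

# Apply the weight block chosen by the selector to the input vector P
(L::SelectLinear)(P::AbstractVector, x) = view(L.W, :, :, L.select(x)) * P

nparams(L::SelectLinear) = length(L.W)

radial  = SelectLinear(6, 30, 4, x -> x.pair)   # like linl:    6 -> 30, 4 species pairs
readout = SelectLinear(6, 1, 2, x -> x.z)       # like readout: 6 -> 1, 2 species
readout.W[:, :, 2] .= 1.0                       # give species 2 all-ones weights
```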

Assemble the ETACE model

scratch_etace = ETM.ETACE(rembed, yembed, mb_basis, readout)

# Initialize with Lux
scratch_ps, scratch_st = Lux.setup(rng, scratch_etace)

@info "ETACE model created from scratch"
@info "  Radial basis size: $(length(Rnl_spec))"
@info "  Angular basis size: $(length(Ylm_spec))"
@info "  Many-body basis size: $(mb_basis.lens[1])"

# Wrap in calculator
scratch_calc = ETM.ETACEPotential(scratch_etace, scratch_ps, scratch_st, scratch_rcut)

# Test evaluation (the parameters are freshly initialized, so this energy is not physically meaningful)
E_scratch = AtomsCalculators.potential_energy(sys, scratch_calc)
@info "Scratch model energy: $E_scratch"
[ Info: ETACE model created from scratch
[ Info:   Radial basis size: 30
[ Info:   Angular basis size: 25
[ Info:   Many-body basis size: 6
[ Info: Scratch model energy: -5.882066501882554 eV
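Because mb_spec keeps only l = 0 one-correlations, this scratch model has a particularly simple structure: for each atom i, B_{i,n} = Y_00 Σ_j R_{n0}(r_ij), where R_{n0} depends on the (z_i, z_j) species pair, and E = Σ_i w_{z_i} ⋅ B_i. The toy re-implementation below mirrors this data flow in plain Julia for a small non-periodic cluster. It checks that the energy is invariant under rotations and translations and is unaffected by atoms beyond rcut. It is a sketch under assumptions (weight layout, pair ordering, Y_00 normalization), not the ETACE implementation:

```julia
using LinearAlgebra, Random

rcut, maxn = 5.5, 6
species = [:Si, :O]
NZ = length(species)
sidx(z) = findfirst(==(z), species)
pidx(z0, z1) = (sidx(z0) - 1) * NZ + sidx(z1)      # hypothetical pair ordering

# Envelope-weighted Chebyshev basis of y = 1 - 2r/rcut
function Penv(r)
    y = 1 - 2r / rcut
    T = zeros(maxn)
    T[1] = 1.0
    T[2] = y
    for k in 3:maxn
        T[k] = 2y * T[k-1] - T[k-2]
    end
    return (1 - y^2)^2 .* T
end

const Y00 = 1 / (2 * sqrt(pi))   # normalized real Y_00 (normalization assumed)

rng = MersenneTwister(1)
Wrad = randn(rng, maxn, maxn, NZ^2)   # l = 0 rows of the radial layer: (n, k, pair)
wout = randn(rng, maxn, NZ)           # readout weights: (n, center species)

# Site energy E_i = w_{z_i} ⋅ B_i with B_i = Y00 * Σ_j R_{n0}(r_ij)
function site_energy(i, X, Z)
    B = zeros(maxn)
    for j in eachindex(X)
        j == i && continue
        r = norm(X[j] - X[i])
        r < rcut || continue
        B .+= Y00 .* (Wrad[:, :, pidx(Z[i], Z[j])] * Penv(r))
    end
    return dot(wout[:, sidx(Z[i])], B)
end

total_energy(X, Z) = sum(site_energy(i, X, Z) for i in eachindex(X))

X = [[0.0, 0.0, 0.0], [2.3, 0.1, 0.0], [0.2, 2.2, 0.4], [1.5, 1.4, 1.9]]
Z = [:Si, :O, :O, :Si]
E = total_energy(X, Z)
```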

Part 5: Creating a One-Body Model from Scratch

ETOneBody: Reference energies

# Define reference energies per species
E0_dict = Dict(ChemicalSpecies(:Si) => -0.846,
               ChemicalSpecies(:O) => -1.023)

# Category function extracts species from atom state
catfun = x -> x.z  # x.z is the ChemicalSpecies

# Create one-body model
et_onebody = ETM.one_body(E0_dict, catfun)
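# ETOneBody has no trainable parameters; only the state is needed
_, onebody_st = Lux.setup(rng, et_onebody)

# Wrap in calculator; the parameters are passed as nothing, and the cutoff
# (3.0 Å) is arbitrary because the one-body term needs no neighbors
onebody_calc = ETM.ETOneBodyPotential(et_onebody, nothing, onebody_st, 3.0)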

@info "ETOneBody model created"
@info "  Reference energies: $E0_dict"

E_onebody = AtomsCalculators.potential_energy(sys, onebody_calc)
@info "  One-body energy for test system: $E_onebody"
[ Info: ETOneBody model created
[ Info:   Reference energies: Dict{AtomsBase.ChemicalSpecies, Float64}(Si => -0.846, O => -1.023)
[ Info:   One-body energy for test system: -7.652999999999999 eV
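The one-body energy is simply Σ_i E0[z_i]. With 8 atoms and two species, the printed value of about -7.653 eV pins down the composition of this random configuration (3 Si and 5 O). It is also identical to the full-calculator energy in Part 2, consistent with only the E0 term contributing there. The arithmetic, in plain Julia:

```julia
E0 = Dict(:Si => -0.846, :O => -1.023)

# One-body energy: sum of per-species reference energies
onebody_energy(Z) = sum(E0[z] for z in Z)

# Which composition of 8 atoms reproduces the printed energy?
E_printed = -7.652999999999999
n_Si = round(Int, (E_printed - 8 * E0[:O]) / (E0[:Si] - E0[:O]))

Z = vcat(fill(:Si, n_Si), fill(:O, 8 - n_Si))
```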

Part 6: Combining Models with StackedCalculator

StackedCalculator combines multiple calculators by summing their contributions; for training assembly, the linear bases of the components are concatenated.

# Stack our from-scratch models
combined_calc = ETM.StackedCalculator((onebody_calc, scratch_calc))

@info "StackedCalculator created"
@info "  Components: ETOneBody + ETACE"
@info "  Total basis size: $(ETM.length_basis(combined_calc))"

# Evaluate combined model
E_combined = AtomsCalculators.potential_energy(sys, combined_calc)
@info "  Combined energy: $E_combined"

# Training assembly works on StackedCalculator too
efv_combined = ETM.energy_forces_virial_basis(sys, combined_calc)
@info "  Combined EFV basis shapes: E=$(size(efv_combined.energy)), F=$(size(efv_combined.forces))"

@info "Tutorial complete!"
[ Info: StackedCalculator created
[ Info:   Components: ETOneBody + ETACE
[ Info:   Total basis size: 12
[ Info:   Combined energy: -13.535066501882554 eV
[ Info:   Combined EFV basis shapes: E=(12,), F=(8, 12)
[ Info: Tutorial complete!
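The printed numbers illustrate the two rules a stacked calculator follows in this tutorial. Energies add: -7.653 + (-5.882) ≈ -13.535 eV. Linear bases and parameters are concatenated: 0 + 12 = 12, since the one-body term carries no linear parameters, just as 0 + 20 + 220 = 240 for the converted model in Part 1. A toy model of these two rules in plain Julia; the types are hypothetical and not ETM.StackedCalculator:

```julia
using LinearAlgebra

# Hypothetical toy calculator: fixed (parameter-free) energy plus a linear part
struct ToyCalc{F,G}
    fixed::F              # structure -> parameter-free energy (e.g. an E0 sum)
    basis::G              # structure -> vector of basis values
    θ::Vector{Float64}    # linear parameters
end
energy(c::ToyCalc, s) = c.fixed(s) + dot(c.basis(s), c.θ)

# Hypothetical toy stack: energies add, linear bases and parameters concatenate
struct ToyStack{T<:Tuple}
    calcs::T
end
energy(st::ToyStack, s) = sum(energy(c, s) for c in st.calcs)
energy_basis(st::ToyStack, s) = reduce(vcat, [c.basis(s) for c in st.calcs])
linear_params(st::ToyStack) = reduce(vcat, [c.θ for c in st.calcs])

E0 = Dict(:Si => -0.846, :O => -1.023)
onebody = ToyCalc(Z -> sum(E0[z] for z in Z), Z -> Float64[], Float64[])
manybody = ToyCalc(Z -> 0.0, Z -> fill(Float64(length(Z)), 12), collect(range(0.1, 1.2; length = 12)))
stacked = ToyStack((onebody, manybody))

Z = vcat(fill(:Si, 3), fill(:O, 5))
```

In the toy, the parameter-free E0 part is not represented in the linear basis, so energy_basis ⋅ linear_params recovers only the non-E0 part of the energy.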

This page was generated using Literate.jl.