Predicting with the trained emulators
This notebook calls the registered `haloemu` emulators for a cosmology θ. Every
prediction is one line:

```python
reg.predict(property, gravity, redshift, theta)
```
The trained artifacts ship inside the package, so this notebook runs on a plain
`pip install`, without halocat/freyja or any simulation data (see the
Portability page). A single prediction takes milliseconds.
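```python
import os

# Prefer the CPU backend. This must be set before JAX is imported; on nodes
# without a GPU the CUDA plugin may still log a harmless initialization error.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import numpy as np
import matplotlib.pyplot as plt
from haloemu import get_registry

reg = get_registry()
print("haloemu registry loaded")
```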
```
haloemu registry loaded
```
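To check the millisecond claim on your own machine, you can use a small timing helper built only on the standard library. `median_call_time_ms` is a hypothetical helper and is not part of `haloemu`. It runs a few warm-up calls first and excludes them from the timing, because the first call may include one-off costs. One example is JAX compilation, assuming the emulators are jitted.

```python
import statistics
import time
from typing import Any, Callable


def median_call_time_ms(fn: Callable[..., Any], *args: Any,
                        repeats: int = 20, warmup: int = 1, **kwargs: Any) -> float:
    """Median wall-clock time of fn(*args, **kwargs) in milliseconds.

    The first `warmup` calls run but are not timed, so one-off costs
    (compilation, lazy loading) do not inflate the estimate.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args, **kwargs)
        times.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(times)
```

Example usage: `median_call_time_ms(lambda: np.asarray(reg.predict("hmf", "LCDM", 0.25, [0.31, 0.677, 0.967, 0.83])))`. If `predict` returns JAX arrays, wrapping the call in `np.asarray` forces JAX's asynchronous dispatch to finish. Without it, the timer would measure only the dispatch.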
What is registered
`reg.list()` returns one entry per (property, gravity, redshift) key. The
matter sector (`hmf`, `pk_mm`, `xi_mm`) is registered at z = 0.25 and
z = 0.00; every other property only at z = 0.25. Every property is available
for both gravities, `LCDM` and `fRn1`.
```python
entries = reg.list()
print(f"{len(entries)} artifacts\n")
print(f"{'key':<24}{'kind':<9}{'class'}")
for e in entries:
    # Derived recipes carry no artifact class, so label them "recipe".
    cls = e.get("artifact_class", "recipe").split(".")[-1]
    print(f"{e['key']:<24}{e['kind']:<9}{cls}")
```
```
38 artifacts

key                     kind     class
b_cum/LCDM/z0.25        trained  PeakHeightEmulator
b_cum/fRn1/z0.25        trained  PeakHeightEmulator
b_diff/LCDM/z0.25       derived  recipe
b_diff/fRn1/z0.25       derived  recipe
hmf/LCDM/z0.00          trained  SavedEmulator
hmf/LCDM/z0.25          trained  SavedEmulator
hmf/fRn1/z0.00          trained  MGBoostEmulator
hmf/fRn1/z0.25          trained  MGBoostEmulator
pk_mm/LCDM/z0.00        trained  PkMMEmulator
pk_mm/LCDM/z0.25        trained  PkMMEmulator
pk_mm/fRn1/z0.00        trained  MGBoostEmulator
pk_mm/fRn1/z0.25        trained  MGBoostEmulator
r_ab/LCDM/z0.25         trained  RAbEmulator
r_ab/fRn1/z0.25         trained  RAbEmulator
vel_c02/LCDM/z0.25      trained  VelMomentEmulator
vel_c02/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_c04/LCDM/z0.25      trained  VelMomentEmulator
vel_c04/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_c12/LCDM/z0.25      trained  VelMomentEmulator
vel_c12/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_c20/LCDM/z0.25      trained  VelMomentEmulator
vel_c20/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_c22/LCDM/z0.25      trained  VelMomentEmulator
vel_c22/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_c30/LCDM/z0.25      trained  VelMomentEmulator
vel_c30/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_c40/LCDM/z0.25      trained  VelMomentEmulator
vel_c40/fRn1/z0.25      trained  VelMomentBoostEmulator
vel_m10/LCDM/z0.25      trained  VelMomentEmulator
vel_m10/fRn1/z0.25      trained  VelMomentEmulator
xi_hh/LCDM/z0.25        derived  recipe
xi_hh/fRn1/z0.25        derived  recipe
xi_hh_smallr/LCDM/z0.25 trained  XiABSmallREmulator
xi_hh_smallr/fRn1/z0.25 trained  XiABSmallREmulator
xi_mm/LCDM/z0.00        trained  XiMMEmulator
xi_mm/LCDM/z0.25        trained  XiMMEmulator
xi_mm/fRn1/z0.00        trained  MGBoostEmulator
xi_mm/fRn1/z0.25        trained  MGBoostEmulator
```
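Keys follow the pattern `property/gravity/z<redshift>`, as in the listing above. A small helper can split a key back into (property, gravity, redshift), for example to check which redshifts each property supports. `parse_key` and `redshifts_by_property` are hypothetical helpers. They assume the key format stays exactly as printed.

```python
from collections import defaultdict


def parse_key(key: str) -> tuple[str, str, float]:
    """Split a registry key such as 'pk_mm/LCDM/z0.25' into its three parts."""
    prop, gravity, ztag = key.split("/")
    if not ztag.startswith("z"):
        raise ValueError(f"unexpected redshift tag in {key!r}")
    return prop, gravity, float(ztag[1:])


def redshifts_by_property(keys):
    """Map (property, gravity) -> sorted list of registered redshifts."""
    table = defaultdict(list)
    for key in keys:
        prop, gravity, z = parse_key(key)
        table[(prop, gravity)].append(z)
    return {pg: sorted(zs) for pg, zs in table.items()}
```

Example usage: `redshifts_by_property(e["key"] for e in reg.list())`.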
θ conventions
| Gravity | θ (in order) |
|---|---|
| `LCDM` | `[Omega_m, h, n_s, S_8]` |
| `fRn1` | `[Omega_m, h, n_s, S_8, logf_R0]` |
Pass `S_8`, not `sigma_8`: `S_8 = sigma_8 * sqrt(Omega_m / 0.3)`. θ must lie
inside the design hull (Omega_m ∈ [0.15, 0.445], h ∈ [0.596, 0.795],
n_s ∈ [0.94, 0.989], S_8 ∈ [0.65, 0.945], logf_R0 ∈ [−7, −4.05]). Predictions
outside it are unvalidated extrapolations.
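The sketch below builds θ in the order given in the table, converts `sigma_8` to `S_8`, and rejects values outside the quoted ranges. `make_theta`, `s8_from_sigma8`, `DESIGN_RANGES` and `THETA_ORDER` are hypothetical names. The ranges are treated as a per-parameter box, which is an assumption: the true design hull may be tighter than this box.

```python
import math

# Per-parameter ranges quoted above (a box; the actual design hull may be tighter).
DESIGN_RANGES = {
    "Omega_m": (0.15, 0.445),
    "h": (0.596, 0.795),
    "n_s": (0.94, 0.989),
    "S_8": (0.65, 0.945),
    "logf_R0": (-7.0, -4.05),
}
THETA_ORDER = {
    "LCDM": ["Omega_m", "h", "n_s", "S_8"],
    "fRn1": ["Omega_m", "h", "n_s", "S_8", "logf_R0"],
}


def s8_from_sigma8(sigma8: float, omega_m: float) -> float:
    """S_8 = sigma_8 * sqrt(Omega_m / 0.3)."""
    return sigma8 * math.sqrt(omega_m / 0.3)


def make_theta(gravity: str, *, Omega_m, h, n_s, sigma8, logf_R0=None):
    """Return theta in the order expected for `gravity`; raise if a value is out of range."""
    params = {"Omega_m": Omega_m, "h": h, "n_s": n_s,
              "S_8": s8_from_sigma8(sigma8, Omega_m), "logf_R0": logf_R0}
    theta = []
    for name in THETA_ORDER[gravity]:
        value = params[name]
        if value is None:
            raise ValueError(f"{name} is required for {gravity}")
        lo, hi = DESIGN_RANGES[name]
        if not lo <= value <= hi:
            raise ValueError(f"{name}={value} outside [{lo}, {hi}]: extrapolation")
        theta.append(value)
    return theta
```

The helper raises an error rather than warning, so you cannot extrapolate by accident. Replace the `raise` with `warnings.warn` if you want to extrapolate deliberately.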
```python
th = [0.31, 0.677, 0.967, 0.83]          # LCDM: [Omega_m, h, n_s, S_8]
th5 = [0.31, 0.677, 0.967, 0.83, -5.0]   # fRn1: same background, logf_R0 = -5

n = reg.predict("hmf", "LCDM", 0.25, th)
pk = reg.predict("pk_mm", "LCDM", 0.25, th)
xi = reg.predict("xi_mm", "LCDM", 0.25, th)
b = reg.predict("b_cum", "LCDM", 0.25, th)

for name, arr in [("hmf", n), ("pk_mm", pk), ("xi_mm", xi), ("b_cum", b)]:
    print(f"{name:7s} -> shape {np.asarray(arr).shape}")
```
```
hmf     -> shape (1, 23)
pk_mm   -> shape (1, 147)
xi_mm   -> shape (1, 113)
b_cum   -> shape (1, 30)
```
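Each call returns a leading axis of length 1 for the single θ. This page does not say whether `reg.predict` also accepts a batch of θ. A safe way to evaluate several cosmologies is therefore to loop and stack the results. `predict_many` is a hypothetical helper that assumes every call returns shape `(1, ...)`, as in the shapes printed above.

```python
import numpy as np


def predict_many(predict_one, thetas):
    """Call predict_one(theta) for each theta and stack along the leading axis.

    Assumes each call returns an array of shape (1, ...), so the result has
    shape (len(thetas), ...).
    """
    rows = [np.asarray(predict_one(theta)) for theta in thetas]
    return np.concatenate(rows, axis=0)
```

Example usage, varying `S_8` inside its design range: `predict_many(lambda t: reg.predict("pk_mm", "LCDM", 0.25, t), [[0.31, 0.677, 0.967, s8] for s8 in (0.78, 0.83, 0.88)])`.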
Coordinates and a first plot
A trained artifact exposes its coordinate grid as `art.coord`: log10 M edges,
k in h/Mpc, or r in Mpc/h, depending on the property. Here we plot the mass
function and the matter power spectrum.
```python
hmf_art = reg.load("hmf", "LCDM", 0.25)
pk_art = reg.load("pk_mm", "LCDM", 0.25)
logM = np.asarray(hmf_art.coord)
k = np.asarray(pk_art.coord)

fig, (axL, axR) = plt.subplots(1, 2, figsize=(11, 4))
axL.semilogy(logM, n[0], "o-", ms=3)
axL.set_xlabel(r"$\log_{10} M\ [M_\odot/h]$")
axL.set_ylabel(r"$n(>M)\ [(\mathrm{Mpc}/h)^{-3}]$")
axL.set_title("Cumulative HMF (LCDM, z=0.25)")
axR.loglog(k, pk[0])
axR.set_xlabel(r"$k\ [h/\mathrm{Mpc}]$")
axR.set_ylabel(r"$P_{mm}(k)\ [(\mathrm{Mpc}/h)^3]$")
axR.set_title("Matter power spectrum")
fig.tight_layout()
plt.show()
```
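The mass function above is cumulative, n(>M). Differencing it between neighbouring thresholds gives the number density in each mass bin, and dividing by the bin width gives dn/dlog10 M. The sketch below assumes that `hmf_art.coord` holds the thresholds at which n(>M) is evaluated, as the "log10 M edges" wording suggests, and that it has the same length as `n[0]`. `differential_from_cumulative` is a hypothetical helper.

```python
import numpy as np


def differential_from_cumulative(log10m_edges, n_cum):
    """Convert cumulative n(>M) at mass thresholds to dn/dlog10M per bin.

    n(>M) decreases with M, so the density in bin i is n(>M_i) - n(>M_{i+1});
    dividing by the bin width in log10 M gives dn/dlog10M.
    Returns (bin_centres, dn_dlog10m), each one element shorter than the input.
    """
    log10m_edges = np.asarray(log10m_edges, dtype=float)
    n_cum = np.asarray(n_cum, dtype=float)
    dn = n_cum[:-1] - n_cum[1:]
    width = np.diff(log10m_edges)
    centres = 0.5 * (log10m_edges[:-1] + log10m_edges[1:])
    return centres, dn / width
```

Example usage: `centres, dndlogm = differential_from_cumulative(logM, n[0])`.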
f(R) as a boost over ΛCDM
The `fRn1` matter artifacts are seed-paired boosts composed onto the pinned
ΛCDM base. Predict both gravities at the same background θ (the first four
entries of `th5` equal `th`) and take the ratio to see the modified-gravity
enhancement.
```python
pk_lcdm = reg.predict("pk_mm", "LCDM", 0.25, th)
pk_fr = reg.predict("pk_mm", "fRn1", 0.25, th5)
ratio = pk_fr[0] / pk_lcdm[0]

fig, ax = plt.subplots(figsize=(6, 4))
ax.semilogx(k, ratio)
ax.axhline(1.0, color="k", lw=0.7, ls="--")
ax.set_xlabel(r"$k\ [h/\mathrm{Mpc}]$")
ax.set_ylabel(r"$P_{mm}^{f(R)} / P_{mm}^{\Lambda\mathrm{CDM}}$")
ax.set_title(r"f(R) boost, $\log_{10}|f_{R0}| = -5$")
fig.tight_layout()
plt.show()
print("max enhancement:", float(ratio.max()))
```
```
max enhancement: 1.3015154550847226
```
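Besides its maximum, the ratio curve can be summarised by the wavenumber at which the enhancement first exceeds some level, such as 1%. This is a minimal sketch. It assumes `k` is sorted in ascending order, as it appears to be for `pk_mm`. The 1% threshold and the name `onset_scale` are illustrative choices and not part of `haloemu`.

```python
import numpy as np


def onset_scale(k, boost, threshold=1.01):
    """Smallest k at which boost(k) exceeds `threshold`, or None if it never does.

    Assumes k is sorted in ascending order.
    """
    k = np.asarray(k, dtype=float)
    above = np.asarray(boost, dtype=float) > threshold
    if not above.any():
        return None
    # argmax on a boolean array returns the index of the first True entry.
    return float(k[np.argmax(above)])
```

Example usage: `onset_scale(k, np.asarray(ratio))`.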
Velocity moments (surface artifacts)
The `vel_*` moments return a `(1, 10, 60)` surface: 10 halo mass-bin pairs ×
60 radial bins. The pair labels and the radial grid are stored on
`art.pair_keys` and `art.r`.
```python
m10_art = reg.load("vel_m10", "LCDM", 0.25)
r = np.asarray(m10_art.r)
m10 = reg.predict("vel_m10", "LCDM", 0.25, th)  # (1, 10, 60)
print("pairs:", list(m10_art.pair_keys))

ip = 0  # first mass-bin pair
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(r, m10[0, ip])
ax.axhline(0.0, color="k", lw=0.7, ls="--")
ax.set_xlabel(r"$r\ [\mathrm{Mpc}/h]$")
ax.set_ylabel(r"$m_{10}(r)$ (mean pairwise infall)")
ax.set_title(f"vel_m10, mass pair {m10_art.pair_keys[ip]}")
fig.tight_layout()
plt.show()
```
```
pairs: [(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
```
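The ten pairs printed above are the unordered pairs (i ≤ j) of four halo mass bins, so the surface has 4·5/2 = 10 rows. The sketch below regenerates that order and finds the row for a given pair. `mass_pair_keys` and `pair_index` are hypothetical helpers. The most robust lookup does not assume any ordering: `list(m10_art.pair_keys).index((i, j))`.

```python
from itertools import combinations_with_replacement


def mass_pair_keys(n_bins):
    """Unordered mass-bin pairs (i <= j) in lexicographic order."""
    return list(combinations_with_replacement(range(n_bins), 2))


def pair_index(i, j, n_bins):
    """Row of the pair (i, j) in mass_pair_keys(n_bins); the order of i and j does not matter."""
    return mass_pair_keys(n_bins).index((min(i, j), max(i, j)))
```

For example, `m10[0, pair_index(2, 1, 4)]` selects the (1, 2) curve, provided `pair_keys` uses this order.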
Derived properties: b_diff and xi_hh
These are recipes: the registry resolves their upstream artifacts and returns a
dict instead of an array. They also predict standalone. For example, the mean
halo number density n̄ that `b_diff` needs comes from the in-suite `hmf`
emulator.
```python
bd = reg.predict("b_diff", "LCDM", 0.25, th)
print("b_diff keys:", [kk for kk in bd if not kk.startswith('_')][:8])
print("checks:", {kk: bd['checks'][kk] for kk in ('b_le_bcum', 'monotonic', 'boundary_ok')})
print("nbar source:", bd.get("nbar_source"))

xhh = reg.predict("xi_hh", "LCDM", 0.25, th)
print("\nxi_hh shape:", xhh["xi_hh"].shape,
      "| r range [%.2f, %.1f]" % (xhh["r"][0], xhh["r"][-1]))
print("per-bin <b>:", np.round(xhh["b_bin"], 3))
```
```
b_diff keys: ['theta', 'theta_keys', 'redshift', 'x_grid', 'support', 'extrapolated', 'b_diff', 'b_diff_xform']
checks: {'b_le_bcum': True, 'monotonic': True, 'boundary_ok': True}
nbar source: haloemu_hmf_artifact

xi_hh shape: (15, 15, 113) | r range [2.04, 124.8]
per-bin <b>: [1.039 1.086 1.137 1.187 1.256 1.332 1.415 1.512 1.624 1.705 1.81 1.969 2.117 2.308 2.512]
```
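The `b_le_bcum` check is consistent with a general property. Suppose the cumulative bias b_cum(>M) is the number-weighted mean bias of all haloes above a threshold and bias rises with mass. Then the bias at the threshold cannot exceed b_cum. The sketch below illustrates this relation on binned quantities and assumes that definition of b_cum. It is not the recipe's own inversion from `b_cum` to `b_diff`, which is not shown here, and `cumulative_bias` is a hypothetical helper.

```python
import numpy as np


def cumulative_bias(b_diff, dn):
    """Number-weighted mean bias of all bins at or above each threshold.

    b_diff[i] is the bias of mass bin i and dn[i] its number density, with bins
    ordered by increasing mass. Returns b_cum[i] = sum_{j>=i} b_j dn_j / sum_{j>=i} dn_j.
    """
    b_diff = np.asarray(b_diff, dtype=float)
    dn = np.asarray(dn, dtype=float)
    # Reversed cumulative sums accumulate everything at or above each threshold.
    num = np.cumsum((b_diff * dn)[::-1])[::-1]
    den = np.cumsum(dn[::-1])[::-1]
    return num / den
```

With an increasing `b_diff`, `np.all(b_diff <= cumulative_bias(b_diff, dn))` holds, which mirrors the `b_le_bcum` check.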
Predictive variance
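For trained artifacts, pass `return_var=True` to get a `(mean, variance)` pair.
Below, the square root of the variance is divided by the mean to give the
relative GP uncertainty at each mass threshold.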
```python
mean, var = reg.predict("hmf", "LCDM", 0.25, th, return_var=True)
rel = np.sqrt(var[0]) / mean[0]
print("median GP relative sigma on n(>M): %.2e" % np.median(rel))
```
```
median GP relative sigma on n(>M): 4.63e-04
```
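The variance can also be turned into a shaded band for the HMF plot, or into an uncertainty in dex on log10 n. This minimal sketch assumes `var` is the predictive variance of the mean in the same linear units, which is how the relative sigma above treats it. `gp_band` is a hypothetical helper. The dex conversion is the first-order propagation σ / (n ln 10).

```python
import numpy as np


def gp_band(mean, var, nsigma=1.0):
    """Return (lower, upper, sigma_dex) for mean -/+ nsigma * sqrt(var).

    sigma_dex = sigma / (mean * ln 10) is the first-order error on log10(mean).
    """
    mean = np.asarray(mean, dtype=float)
    sigma = np.sqrt(np.asarray(var, dtype=float))
    lower = mean - nsigma * sigma
    upper = mean + nsigma * sigma
    sigma_dex = sigma / (mean * np.log(10.0))
    return lower, upper, sigma_dex
```

Example usage: `lo, hi, dex = gp_band(mean[0], var[0])`, then `ax.fill_between(logM, lo, hi, alpha=0.3)` on a mass-function axis.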
Next steps
- Per-property calling conventions, shapes, and kwargs: see the Predicting guide and the Machine reference.
- Accuracy and caveats (frozen-seed offset, BAO mask): see the Accuracy & gates and Caveats pages.
- To fit your own artifacts, see the Training new emulators notebook.