JAX-based models

Deploying JAX models from C++ simulation codes is notoriously difficult. The example/jax_example directory shows how to wrap a JAX MD model and expose it to KIM-API simulators through KUSP. The workflow is the same as in the Lennard-Jones and NequIP tutorials, so you can swap components without learning new tooling.

1. Decorate the JAX model

  • JAXSiSW.py contains a @kusp_model entry point that wires up the JAX MD potential, JIT-compiles the energy/force function, and returns NumPy arrays.

  • Adjust cutoffs, species, or model weights there if you change the physics.

The wrapper is intentionally small; most of the file is standard JAX MD setup:

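import jax
import jax.numpy as jnp
import numpy as np
from kusp import kusp_model


# influence_distance and species are reported to KIM-API through KUSP.
@kusp_model(influence_distance=3.2, species=("Si",))
class JAXSiSW:
    def __init__(self):
        # Wrap the energy/force function in jax.jit once; it is compiled on
        # the first call and the compiled version is reused afterwards.
        self.energy_force = jax.jit(self._build_energy_force())

    def _build_energy_force(self):
        # Excerpt: in JAXSiSW.py this builds the potential from JAX MD primitives.
        raise NotImplementedError

    def __call__(self, species, positions, contributing):
        # contributing flags the atoms whose energy counts (padding atoms are False).
        energy, forces = self.energy_force(
            jnp.array(positions), jnp.array(contributing, dtype=jnp.bool_)
        )
        # Hand plain NumPy arrays back to KUSP.
        return np.asarray(energy), np.asarray(forces)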

Apart from this wrapper, the file is devoted to defining _build_energy_force with JAX MD primitives. No simulator-specific logic is required.
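The body of _build_energy_force is not shown above. The factory below is a rough sketch of the pattern such a function can follow. It assumes a hypothetical, smoothly truncated Lennard-Jones pair term in place of the Stillinger-Weber potential, and it omits the three-body terms that Stillinger-Weber also has. It returns a function of (positions, contributing) whose energy sums only over contributing atoms and whose forces are the negative gradient. The names build_energy_force and pair_energy and all parameter values are illustrative and not taken from JAXSiSW.py. The 3.2 cutoff only mirrors the influence_distance above.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in parameters: a shifted Lennard-Jones pair term,
# NOT the Stillinger-Weber potential used in JAXSiSW.py.
CUTOFF = 3.2
EPSILON = 1.0
SIGMA = 2.0


def pair_energy(r):
    """Lennard-Jones pair energy, shifted so it is zero at CUTOFF."""
    sr6 = (SIGMA / r) ** 6
    sc6 = (SIGMA / CUTOFF) ** 6
    return 4.0 * EPSILON * ((sr6 ** 2 - sr6) - (sc6 ** 2 - sc6))


def build_energy_force():
    """Return f(positions, contributing) -> (energy, forces)."""

    def energy_fn(positions, contributing):
        n = positions.shape[0]
        eye = jnp.eye(n, dtype=bool)
        diff = positions[:, None, :] - positions[None, :, :]
        dist2 = jnp.sum(diff ** 2, axis=-1)
        # Replace the zero self-distances with 1.0 so sqrt and 1/r stay finite
        # and differentiable; those entries are masked out below.
        dist = jnp.sqrt(jnp.where(eye, 1.0, dist2))
        mask = (~eye) & (dist < CUTOFF)
        e_pair = jnp.where(mask, pair_energy(dist), 0.0)
        # Each atom owns half of every bond it takes part in; only
        # contributing atoms add their share to the total energy.
        per_atom = 0.5 * jnp.sum(e_pair, axis=1)
        return jnp.sum(jnp.where(contributing, per_atom, 0.0))

    def energy_force(positions, contributing):
        # Differentiate with respect to positions only (argnums=0).
        energy, grad = jax.value_and_grad(energy_fn)(positions, contributing)
        return energy, -grad

    return energy_force
```

In the wrapper, a factory like this would be passed to jax.jit in the same way as self._build_energy_force(). Each contributing atom counts half of each bond. This keeps the total energy correct when KIM-API supplies padding atoms, because every bond that crosses into the padding region is counted once from each side.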

2. Serve with hot reload

kusp serve example/jax_example/JAXSiSW.py \
    --kusp-config example/jax_example/kusp_config.yaml

  • The server prints the path of the generated config once at startup. Export it in the shell where you run your simulator: export KUSP_CONFIG=$PWD/example/jax_example/kusp_config.yaml.

  • After saving changes to JAXSiSW.py, press Ctrl+C once to reload the model and its JIT-compiled functions without restarting the server.

  • Press Ctrl+C twice quickly to stop the server.
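This tutorial does not show how KUSP handles signals internally. The hypothetical InterruptHandler below only illustrates the "once to reload, twice quickly to stop" rule. It assumes a one-second window and uses placeholder reload_model/stop_server callbacks. It is not KUSP's implementation.

```python
import signal
import time

DOUBLE_PRESS_WINDOW = 1.0  # seconds; assumed value, KUSP's threshold may differ


class InterruptHandler:
    """Classify each Ctrl+C press as 'reload' or 'stop' (illustrative only)."""

    def __init__(self, window=DOUBLE_PRESS_WINDOW, clock=time.monotonic):
        self.window = window
        self.clock = clock
        self.last_press = None

    def on_interrupt(self):
        now = self.clock()
        if self.last_press is not None and now - self.last_press <= self.window:
            # Second press inside the window: stop the server.
            self.last_press = None
            return "stop"
        # Isolated press: reload, and remember when it happened.
        self.last_press = now
        return "reload"


def install(handler, reload_model, stop_server):
    """Route SIGINT through the handler (hypothetical callbacks; main thread only)."""

    def on_sigint(signum, frame):
        if handler.on_interrupt() == "stop":
            stop_server()
        else:
            reload_model()

    signal.signal(signal.SIGINT, on_sigint)
```

In this sketch the first press reloads immediately, and a second press inside the window stops the server. An alternative design waits until the window expires before reloading, which avoids a redundant reload just before shutdown.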

3. Validate the wrapper

  • eval_jax_md.py compares the energies and forces from a direct JAX MD call with those returned by the TCP server, so you can confirm that the decorated model reproduces the native results bitwise.

  • Si.xyz provides the sample configurations used by the script, but you can replace it with your own trajectories.
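The following is a minimal sketch of the kind of check a validation script can perform. It is not the contents of eval_jax_md.py, and the helper names bitwise_equal and compare_results are hypothetical. The sketch takes (energy, forces) pairs from the native call and from the server. For each pair it reports whether the raw bytes match exactly and what the largest absolute difference is.

```python
import numpy as np


def bitwise_equal(a, b):
    """True only if dtype, shape, and every byte of the data match."""
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    return a.dtype == b.dtype and a.shape == b.shape and a.tobytes() == b.tobytes()


def compare_results(native, served):
    """Compare (energy, forces) from the native call with the served ones."""
    report = {}
    pairs = (("energy", native[0], served[0]), ("forces", native[1], served[1]))
    for name, x, y in pairs:
        x = np.asarray(x)
        y = np.asarray(y)
        if x.shape == y.shape:
            # Upcast before subtracting so the difference itself is not rounded.
            diff = float(np.max(np.abs(x.astype(np.float64) - y.astype(np.float64))))
        else:
            diff = float("inf")
        report[name] = {"bitwise": bitwise_equal(x, y), "max_abs_diff": diff}
    return report
```

A dtype change alone, such as float32 forces coming back as float64, breaks the bitwise check even when max_abs_diff is 0.0. Reporting both values makes that case easy to tell apart from a real numerical discrepancy.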

4. Package for redistribution

kusp export example/jax_example/JAXSiSW.py \
    -n KUSP_JAXSiSW__MO_111111111111_000 \
    --env pip

The output directory contains the hashed Python module, any optional resources, an environment manifest, and a CMakeLists.txt that points at the KUSP__MD_000000000000_000 model driver. Install it directly with kim-api-collections-management install, or share the folder as-is.
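The -n argument and the driver name both follow the OpenKIM ID layout: a descriptive prefix, a double underscore, a two-letter type code (MO for a model, MD for a model driver), a 12-digit number, and a 3-digit version. The sketch below is a hypothetical pre-export check built on an approximate regular expression. It covers only the MO and MD codes and is not part of KUSP. The full OpenKIM naming rules may be stricter.

```python
import re

# Approximate pattern for model (MO) and model-driver (MD) KIM IDs.
KIM_ID = re.compile(
    r"^(?P<prefix>[A-Za-z][A-Za-z0-9]*(?:_[A-Za-z0-9]+)*)"
    r"__(?P<kind>MO|MD)_(?P<number>\d{12})_(?P<version>\d{3})$"
)


def parse_kim_id(name):
    """Split a KIM ID into its parts, or raise ValueError if it does not match."""
    match = KIM_ID.match(name)
    if match is None:
        raise ValueError(f"not a model/driver KIM ID: {name!r}")
    return match.groupdict()
```

Running parse_kim_id on the value passed to -n before calling kusp export catches malformed names, such as a single underscore before MO or a short numeric field, before packaging begins.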