pybamm.JaxSolver integrates PyBaMM models using JAX, enabling JIT compilation, GPU/TPU execution, and end-to-end differentiation through the solver.
Constructor
Integration method:
| Method | Description |
|---|---|
| `"BDF"` | Custom JAX BDF integrator (`jax_bdf_integrate`). Supports implicit (stiff) systems. Initial conditions are computed internally. |
| `"RK45"` | Uses `jax.experimental.ode.odeint`. Explicit; ODE-only. |
Method for computing consistent initial conditions. `None` uses the BDF solver’s internal Newton chord method; any `BaseSolver`-compatible method can be supplied instead.
Relative tolerance.
Absolute tolerance.
Tolerance for extrapolation detection.
Behaviour on extrapolation: `"warn"`, `"error"`, or `"ignore"`.
Additional options forwarded to the JAX integrator.
Requirements
| Requirement | Detail |
|---|---|
| Python | ≥ 3.11 |
| Platform | Not darwin/x86_64 (Intel macOS) |
| JAX version | ≥ 0.7.0, < 0.9.0 |
| Model format | `model.convert_to_format` must be `"jax"` |
| Events | Models must have no termination events |
get_solve
Returns a compiled JAX function for repeated fast solves. Subsequent calls with the same `t_eval` reuse the compiled function instead of recompiling.
Examples
When using `JaxSolver`, the model must be set up with `convert_to_format="jax"` and must not contain any termination events (e.g. voltage cut-off events). A `RuntimeError` is raised otherwise.