pybamm.JaxSolver integrates PyBaMM models using JAX, enabling JIT compilation, GPU/TPU execution, and end-to-end differentiation through the solver.
JaxSolver requires Python ≥ 3.11 and is not supported on Intel-based macOS (darwin/x86_64). Install it with the jax extra: pip install "pybamm[jax]" (the quotes stop shells such as zsh from expanding the brackets).

Constructor

pybamm.JaxSolver(
    method="BDF",
    root_method=None,
    rtol=1e-6,
    atol=1e-6,
    extrap_tol=None,
    on_extrapolation=None,
    extra_options=None,
)
method
str
default: "BDF"
Integration method:
Method | Description
"BDF" | Custom JAX BDF integrator (jax_bdf_integrate). An implicit method suited to stiff systems; it also handles models with algebraic equations (DAEs). Consistent initial conditions are computed internally.
"RK45" | jax.experimental.ode.odeint, an explicit adaptive Runge-Kutta method. ODE models only.
root_method
str | None
default: None
Method for computing consistent initial conditions. None uses the BDF solver's internal Newton chord method; any root_method accepted by BaseSolver can be passed instead.
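The chord method mentioned above is Newton's method with the Jacobian computed once and then frozen. The sketch below illustrates the idea in plain NumPy, assuming a semi-explicit DAE in which the differential states are held fixed and only the algebraic states y_a are solved for. chord_newton is a hypothetical helper, not part of PyBaMM, and the finite-difference Jacobian is a simplification (a JAX implementation would more likely use automatic differentiation).

```python
import numpy as np


def chord_newton(g, y_a0, tol=1e-10, max_iter=50, eps=1e-7):
    """Solve g(y_a) = 0 for the algebraic states y_a by a chord iteration.

    Hypothetical helper: the Jacobian dg/dy_a is approximated once, by
    forward differences at the initial guess, and then reused unchanged.
    """
    y_a = np.array(y_a0, dtype=float)
    g0 = np.asarray(g(y_a), dtype=float)
    n = y_a.size
    jac = np.empty((g0.size, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (np.asarray(g(y_a + step)) - g0) / eps
    for _ in range(max_iter):
        # Newton-like update, but always with the frozen Jacobian.
        delta = np.linalg.solve(jac, np.asarray(g(y_a)))
        y_a = y_a - delta
        if np.linalg.norm(delta) < tol:
            return y_a
    raise RuntimeError("chord iteration did not converge")


# Differential state held fixed at y_d = 2; solve y_a**3 + y_a - y_d = 0.
y_d = 2.0
y_a = chord_newton(lambda ya: ya**3 + ya - y_d, [0.9])
print(y_a)  # approximately [1.]
```

Freezing the Jacobian makes each iteration cheap, at the cost of needing a reasonable initial guess: with the same equation, starting from 0.5 instead of 0.9 makes this iteration diverge.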
rtol
float
default: 1e-6
Relative tolerance.
atol
float
default: 1e-6
Absolute tolerance.
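Adaptive integrators commonly combine rtol and atol per state component into a scale atol + rtol * |y_i| and accept a step when the root-mean-square of the scaled error estimate is at most 1; SciPy's BDF implementation works this way. The sketch below assumes JaxSolver's integrators follow the same convention, which this page does not state, and weighted_rms_error is a hypothetical name.

```python
import math


def weighted_rms_error(err, y, rtol=1e-6, atol=1e-6):
    """RMS of err_i / (atol + rtol * |y_i|); a step is accepted when this is <= 1."""
    if len(err) != len(y) or not err:
        raise ValueError("err and y must be non-empty and the same length")
    scaled = [e / (atol + rtol * abs(yi)) for e, yi in zip(err, y)]
    return math.sqrt(sum(s * s for s in scaled) / len(scaled))


# Same local error estimate on a large and a tiny state component:
# for y = 1.0 the scale is atol + rtol = 2e-6, for y = 1e-9 it is ~atol = 1e-6.
print(weighted_rms_error([1e-6, 1e-6], [1.0, 1e-9]))  # about 0.79
```

For components near zero, atol dominates the scale, which is why a too-large atol can hide errors in small-valued states.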
extrap_tol
float | None
default: None (treated as 0)
Tolerance for deciding whether extrapolation has occurred.
on_extrapolation
str | None
default: None (treated as "warn")
Behaviour when extrapolation is detected: "warn", "error", or "ignore".
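The difference between the three on_extrapolation modes can be illustrated with a small, entirely hypothetical check of one value against a data range. It is not how PyBaMM measures extrapolation internally; the function name, the way extrap_tol widens the range, and the use of ValueError for "error" are all assumptions made for illustration.

```python
import warnings


def handle_extrapolation(value, lower, upper, extrap_tol=0.0, on_extrapolation="warn"):
    """Hypothetical check of a value against [lower, upper], widened by extrap_tol.

    Returns True if extrapolation was detected; warns, raises, or stays
    silent depending on on_extrapolation.
    """
    if on_extrapolation not in ("warn", "error", "ignore"):
        raise ValueError(f"unknown on_extrapolation mode: {on_extrapolation!r}")
    outside = value < lower - extrap_tol or value > upper + extrap_tol
    if outside:
        message = f"value {value} outside [{lower}, {upper}] (tolerance {extrap_tol})"
        if on_extrapolation == "error":
            raise ValueError(message)
        if on_extrapolation == "warn":
            warnings.warn(message)
    return outside
```

With "ignore" the detection still happens but nothing is reported, so it is best reserved for cases where extrapolation is known to be harmless.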
extra_options
dict | None
default: None (treated as {})
Additional keyword options forwarded to the underlying JAX integrator.

Requirements

Requirement | Detail
Python | ≥ 3.11
Platform | Not darwin/x86_64 (Intel macOS)
JAX version | >=0.7.0, <0.9.0
Model format | model.convert_to_format must be "jax"
Events | Models must have no termination events
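The requirements in the table can be checked before building a solver. The sketch below is a hypothetical pre-flight helper (jax_solver_problems is not a PyBaMM function); it inspects only the Python version, the platform, and the model's convert_to_format and events attributes, and it does not check the installed JAX version.

```python
import platform
import sys
from types import SimpleNamespace


def jax_solver_problems(model, python_version=None, system=None, machine=None):
    """Return reasons why the model or environment fails the JaxSolver requirements.

    Hypothetical helper; an empty list means no problem was found.
    """
    python_version = python_version or sys.version_info[:2]
    system = system or platform.system()
    machine = machine or platform.machine()
    problems = []
    if tuple(python_version) < (3, 11):
        problems.append("Python >= 3.11 is required")
    if system == "Darwin" and machine == "x86_64":
        problems.append("Intel macOS (darwin/x86_64) is not supported")
    if getattr(model, "convert_to_format", None) != "jax":
        problems.append('model.convert_to_format must be "jax"')
    if getattr(model, "events", []):
        problems.append("model must have no termination events (set model.events = [])")
    return problems


if __name__ == "__main__":
    # Stand-in object with the two model attributes the helper reads.
    demo = SimpleNamespace(convert_to_format="casadi", events=["Minimum voltage [V]"])
    for problem in jax_solver_problems(demo):
        print(problem)
```

Any object exposing convert_to_format and events works, so a PyBaMM model can be passed directly.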

get_solve

Returns a compiled JAX function for repeated fast solves.
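solver.solve(model, t_eval, inputs=inputs)  # sets the model up; required before get_solve
solve_fn = solver.get_solve(model, t_eval)
y = solve_fn(inputs)  # inputs: dict of input-parameter values
The returned function is JIT-compiled on its first call and cached; subsequent calls with the same t_eval reuse the compiled function. The model must already have been set up by the solver (for example by one call to solver.solve); otherwise get_solve raises a RuntimeError.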

Examples

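import numpy as np
import pybamm

model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"  # a model attribute, not an entry in options
model.events = []  # JaxSolver does not support termination events
param = pybamm.ParameterValues("Chen2020")

solver = pybamm.JaxSolver(method="BDF", rtol=1e-6, atol=1e-6)

sim = pybamm.Simulation(model, parameter_values=param, solver=solver)
# Fixed end time short of a full discharge: no voltage cut-off event will stop the solve.
sol = sim.solve(np.linspace(0, 3000, 100))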
When using JaxSolver, set model.convert_to_format = "jax" and remove all termination events (e.g. voltage cut-off events), for example with model.events = []. Otherwise the solver raises a RuntimeError.
