This guide explores three approaches to computing derivatives in Python: symbolic differentiation with SymPy, numerical differentiation with NumPy, and automatic differentiation with JAX. Understanding these methods is crucial for implementing machine learning algorithms efficiently.
Symbolic differentiation struggles with discontinuous functions and can produce overly complex expressions (expression swell), leading to inefficient computation.
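To see expression swell concretely, here is a minimal SymPy sketch (the composed function is an illustrative choice, reused from the numerical example below): each level of composition inflates the size of the derivative expression, which `sympy.count_ops` makes visible.

```python
import sympy as sp

x = sp.symbols('x')

def g(expr):
    return sp.exp(-2*expr) + 3*sp.sin(3*expr)

# Compose the function with itself: each level multiplies the size
# of the symbolic derivative (expression swell)
for depth in range(1, 4):
    expr = x
    for _ in range(depth):
        expr = g(expr)
    d = sp.diff(expr, x)
    print(depth, sp.count_ops(expr), sp.count_ops(d))
```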
```python
import numpy as np

def f_composed(x):
    return np.exp(-2*x) + 3*np.sin(3*x)

# Numerical derivative works regardless of function complexity:
# np.gradient only needs sampled values, not the formula
x_array_2 = np.linspace(0, 2 * np.pi, 1000)  # sample grid (assumed; not shown in the original)
dfdx_num = np.gradient(f_composed(x_array_2), x_array_2)
```
Numerical differentiation doesn't require knowledge of the function's formula, only its values. This makes it versatile but less accurate than symbolic methods. Its main limitations:

- **Inaccuracy at discontinuities:** like symbolic methods, it behaves poorly where the function is not smooth
- **Computational expense:** it requires function evaluation at multiple points
- **Not exact:** it produces approximations subject to truncation and rounding errors (see the sketch after this list)
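The following sketch illustrates the "not exact" point. It compares a central-difference approximation of the derivative of `sin` (whose true derivative is `cos`) against the exact value; the function and step sizes are illustrative choices, not from the original. The error first shrinks as `h` decreases, then grows again once floating-point rounding dominates.

```python
import numpy as np

f = np.sin
x0 = 1.0
exact = np.cos(x0)   # true derivative of sin at x0

# Central differences: error first shrinks with h (truncation error),
# then grows again as floating-point rounding dominates
for h in [1e-1, 1e-4, 1e-6, 1e-10, 1e-13]:
    approx = (f(x0 + h) - f(x0 - h)) / (2 * h)
    print(f"h = {h:.0e}   |error| = {abs(approx - exact):.2e}")
```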
Numerical differentiation is too slow for machine learning because computing a gradient requires at least one extra full function evaluation per parameter, which is impractical for models with hundreds or thousands of parameters.
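A rough sketch of why the cost scales with parameter count: a forward-difference gradient needs one extra full evaluation of the function per parameter. The loss function and sizes here are hypothetical stand-ins.

```python
import numpy as np

def loss(theta):
    # Stand-in for a model's training loss (hypothetical example)
    return np.sum(theta**2)

def fd_gradient(f, theta, h=1e-6):
    # Forward differences: n parameters -> n extra full evaluations of f
    base = f(theta)
    g = np.empty_like(theta)
    for i in range(theta.size):
        bumped = theta.copy()
        bumped[i] += h
        g[i] = (f(bumped) - base) / h
    return g

theta = np.random.rand(1000)
g = fd_gradient(loss, theta)                 # 1001 calls to loss for 1000 parameters
print(np.allclose(g, 2 * theta, atol=1e-3))  # analytic gradient is 2*theta
```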
For polynomial compositions, automatic differentiation excels:
```python
import time

import numpy as np
import jax.numpy as jnp
from jax import grad, vmap
from sympy import symbols, diff, lambdify

def f_polynomial_simple(x):
    return 2*x**3 - 3*x**2 + 5

def f_polynomial(x):
    for _ in range(3):  # compose the cubic with itself three times
        x = f_polynomial_simple(x)
    return x

x = symbols('x')                             # SymPy symbol for the symbolic path
x_array_large = np.linspace(-1, 1, 100_000)  # evaluation grid (assumed; not shown in the original)

# Symbolic: differentiate the composed expression, then evaluate it
tic_symb = time.time()
res_symb = lambdify(x, diff(f_polynomial(x), x), 'numpy')(x_array_large)
time_symb = 1000 * (time.time() - tic_symb)

# Automatic: differentiate the Python function directly with JAX
tic_jax = time.time()
res_jax = vmap(grad(f_polynomial))(jnp.array(x_array_large.astype('float32')))
time_jax = 1000 * (time.time() - tic_jax)

print(f"Symbolic:  {time_symb:.1f} ms")
print(f"Automatic: {time_jax:.1f} ms")
```
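One caveat when timing JAX this way: `grad` and `vmap` trace the function on each call unless it is compiled with `jax.jit`, so for a fairer steady-state comparison you would wrap the gradient in `jax.jit` and discard the first (compilation) call.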
Why is automatic differentiation faster?
Automatic differentiation decomposes the function into a graph of elementary operations and applies the chain rule to each one numerically, so its cost scales with the number of operations rather than the size of a symbolic expression. As function complexity increases, autodiff's advantage grows because it avoids both expression swell (symbolic) and repeated function evaluation (numerical).
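To make the chain-rule mechanics concrete, here is a toy forward-mode (dual-number) sketch. It is illustrative only: JAX actually works by tracing and also supports reverse mode, but the local product rule below is the core idea that sidesteps expression swell.

```python
from dataclasses import dataclass

@dataclass
class Dual:
    """Forward-mode autodiff value: f(x) and f'(x) carried together."""
    val: float
    der: float

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule applied locally; no global symbolic expression is built
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__

def f(x):
    return 3 * x * x + 2 * x   # 3x^2 + 2x, so f'(x) = 6x + 2

print(f(Dual(2.0, 1.0)))  # Dual(val=16.0, der=14.0)
```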
Recommendation: Use automatic differentiation (JAX) for machine learning applications. It combines the accuracy of symbolic methods with superior computational efficiency.