This guide explores three approaches to computing derivatives in Python: symbolic differentiation with SymPy, numerical differentiation with NumPy, and automatic differentiation with JAX. Understanding these methods is crucial for implementing machine learning algorithms efficiently.

Functions in Python

Let’s start by defining a simple function f(x) = x² and its derivative:
def f(x):
    return x**2

print(f(3))  # Output: 9
The analytical derivative can be defined as a separate function:
def dfdx(x):
    return 2*x

print(dfdx(3))  # Output: 6

Working with NumPy Arrays

Apply functions to NumPy arrays element-wise:
import numpy as np

x_array = np.array([1, 2, 3])

print("x:", x_array)
print("f(x) = x**2:", f(x_array))
print("f'(x) = 2x:", dfdx(x_array))
Functions defined with basic arithmetic operators work on NumPy arrays without modification, because NumPy applies those operators element-wise.
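A function built on scalar-only operations, by contrast, does not vectorize on its own. A small illustrative sketch (f_scalar_only and f_wrapped are made-up names for this example):
import math
import numpy as np

def f_scalar_only(x):
    # math.sqrt accepts only scalars, so this function rejects arrays
    return math.sqrt(x) + x**2

x_array = np.array([1.0, 2.0, 3.0])

# f_scalar_only(x_array) would raise a TypeError;
# np.vectorize wraps it so it is applied element by element
f_wrapped = np.vectorize(f_scalar_only)
print(f_wrapped(x_array))  # approx [ 2.    5.41  10.73]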

Symbolic Differentiation

Symbolic computation produces exact derivatives, similar to manual derivation.

Introduction to SymPy

SymPy represents mathematical objects exactly:
from sympy import *

# Symbolic computation preserves exact values
print(sqrt(18))  # Output: 3*sqrt(2)

# Numerical evaluation with specified precision
print(N(sqrt(18), 8))  # Output: 4.2426407

Defining Symbolic Expressions

Create symbolic variables and expressions:
# Define symbols
x, y = symbols('x y')

# Create expression: 2x² - xy
expr = 2 * x**2 - x * y
print(expr)

# Manipulate expressions
expr_manip = x * (expr + x * y + x**3)
print(expand(expr_manip))  # Expand
print(factor(expr_manip))  # Factorize

Evaluating Symbolic Functions

Substitute values into expressions:
# Evaluate at specific point
result = expr.evalf(subs={x: -1, y: 2})
print(result)  # Output: 4.0

# Define function symbolically
f_symb = x ** 2
print(f_symb.evalf(subs={x: 3}))  # Output: 9.0

Making SymPy Functions NumPy-Compatible

Convert symbolic functions for array operations:
from sympy.utilities.lambdify import lambdify

f_symb_numpy = lambdify(x, f_symb, 'numpy')

x_array = np.array([1, 2, 3])
print("f(x) = x**2:", f_symb_numpy(x_array))
# Output: [1 4 9]

Computing Derivatives with SymPy

1. Simple Power Function

# Derivative of x³
print(diff(x**3, x))  # Output: 3*x**2
2. Composed Functions

# SymPy applies chain, product, and sum rules automatically
dfdx_composed = diff(exp(-2*x) + 3*sin(3*x), x)
print(dfdx_composed)
# Output: -2*exp(-2*x) + 9*cos(3*x)
3. Create a NumPy-Friendly Derivative

dfdx_symb = diff(f_symb, x)
dfdx_symb_numpy = lambdify(x, dfdx_symb, 'numpy')

print("f'(x) = 2x:", dfdx_symb_numpy(x_array))
# Output: [2 4 6]

Limitations of Symbolic Differentiation

Symbolic differentiation struggles with discontinuous functions and can produce overly complex expressions (expression swell), leading to inefficient computation.
Example with absolute value function:
# Derivative of |x|
dfdx_abs = diff(abs(x), x)
print(dfdx_abs)  # Complex unevaluated expression

# Evaluation fails for negative values
print(dfdx_abs.evalf(subs={x: -2}))  # Returns unevaluated expression
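Expression swell can be seen by differentiating a repeatedly composed polynomial and watching the derivative's printed size grow; a minimal sketch (the helper name poly is made up for this example):
from sympy import symbols, diff

x_sym = symbols('x')

def poly(z):
    return 2*z**3 - 3*z**2 + 5

# Compose the polynomial with itself; the derivative grows rapidly in size
expr = x_sym
for depth in range(1, 4):
    expr = poly(expr)
    print(f"depth {depth}: derivative has {len(str(diff(expr, x_sym)))} characters")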

Numerical Differentiation

Numerical differentiation approximates derivatives using finite differences.
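The core idea is a finite-difference quotient such as the central difference f'(x) ≈ (f(x + h) - f(x - h)) / (2h) for a small step h; a minimal sketch:
def central_difference(func, x, h=1e-5):
    # Two-point central difference approximation of func'(x)
    return (func(x + h) - func(x - h)) / (2 * h)

print(central_difference(f, 3.0))  # approx 6.0, matching dfdx(3)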

Using NumPy’s Gradient Function

import matplotlib.pyplot as plt

x_array_2 = np.linspace(-5, 5, 100)

# Compute numerical derivative
dfdx_numerical = np.gradient(f(x_array_2), x_array_2)

# Compare with exact derivative
plt.plot(x_array_2, dfdx_symb_numpy(x_array_2), 'r', label="Exact")
plt.plot(x_array_2, dfdx_numerical, 'bo', markersize=3, label="Numerical")
plt.legend()
plt.show()

Numerical Differentiation for Complex Functions

def f_composed(x):
    return np.exp(-2*x) + 3*np.sin(3*x)

# Numerical derivative works regardless of function complexity
dfdx_num = np.gradient(f_composed(x_array_2), x_array_2)
Numerical differentiation doesn’t require knowledge of the function’s formula—only its values. This makes it versatile but less accurate than symbolic methods.
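For example, np.gradient can differentiate data that exists only as samples (here np.sin stands in for measured values):
# Derivative from sampled values alone -- no formula needed
x_samples = np.linspace(0.0, np.pi, 50)
y_samples = np.sin(x_samples)                   # stand-in for measured data
dydx_samples = np.gradient(y_samples, x_samples)

print(dydx_samples[25], np.cos(x_samples[25]))  # approximation vs. exact value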

Limitations of Numerical Differentiation

  1. Inaccuracy at discontinuities: Similar to symbolic methods
  2. Computational expense: Requires function evaluation at multiple points
  3. Not exact: Provides approximations with potential rounding errors
Numerical differentiation is slow for machine learning because it requires full function evaluation for each parameter—impractical for models with hundreds or thousands of parameters.
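To make that cost concrete, a forward-difference gradient of a loss with n parameters needs n + 1 loss evaluations per gradient; a hypothetical sketch (finite_difference_gradient is not a library function):
def finite_difference_gradient(loss, params, h=1e-6):
    # One extra loss evaluation per parameter: n + 1 calls in total
    base = loss(params)
    grads = np.zeros_like(params)
    for i in range(len(params)):
        shifted = params.copy()
        shifted[i] += h
        grads[i] = (loss(shifted) - base) / h
    return grads

# Toy quadratic "loss" over 1000 parameters: 1001 evaluations for one gradient
params = np.ones(1000)
print(finite_difference_gradient(lambda p: np.sum(p**2), params)[:3])  # approx [2. 2. 2.]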

Automatic Differentiation

Automatic differentiation (autodiff) combines accuracy with efficiency by building a computational graph and applying the chain rule.
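The flavor of the idea can be shown with forward-mode autodiff on dual numbers, where every value carries its own derivative and each operation applies the corresponding differentiation rule. This is only a toy sketch, not how JAX works internally:
class Dual:
    """Toy forward-mode value: carries (value, derivative) together."""
    def __init__(self, value, deriv):
        self.value, self.deriv = value, deriv

    def __add__(self, other):
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    def __mul__(self, other):
        # Product rule: (u * v)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

# Differentiate f(x) = x*x + x at x = 3: f(3) = 12, f'(3) = 2*3 + 1 = 7
x_dual = Dual(3.0, 1.0)   # seed dx/dx = 1
y_dual = x_dual * x_dual + x_dual
print(y_dual.value, y_dual.deriv)  # 12.0 7.0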

Introduction to JAX

JAX provides automatic differentiation with NumPy-like syntax:
from jax import grad, vmap
import jax.numpy as jnp

# Create JAX array
x_array_jnp = jnp.array([1., 2., 3.])
print("Type:", type(x_array_jnp))
# Type: <class 'jaxlib.xla_extension.ArrayImpl'>

JAX Array Operations

Most NumPy operations work with JAX arrays:
print(x_array_jnp * 2)  # Element-wise multiplication
print(x_array_jnp[2])   # Indexing
JAX arrays are immutable. Use .at[i].set(value) to create updated copies:
y_array_jnp = x_array_jnp.at[2].set(4.0)
print(y_array_jnp)  # [1. 2. 4.]
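For contrast, an in-place assignment on a JAX array raises an error instead of modifying it:
try:
    x_array_jnp[2] = 4.0           # not allowed: JAX arrays are immutable
except TypeError as err:
    print("In-place update rejected:", err)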

Computing Derivatives with JAX

1. Single-Point Derivative

print("Function at x=3:", f(3.0))
print("Derivative at x=3:", grad(f)(3.0))
# Function at x=3: 9.0
# Derivative at x=3: 6.0
2. Array of Derivatives

# Use vmap for array operations
dfdx_jax_vmap = vmap(grad(f))(x_array_jnp)
print(dfdx_jax_vmap)  # [2. 4. 6.]
3. Visualize Results

# Plot function and derivative
x_plot = jnp.linspace(-5, 5, 100)
plt.plot(x_plot, f(x_plot), 'r', label='f(x)')
plt.plot(x_plot, vmap(grad(f))(x_plot), 'b', label="f'(x)")
plt.legend()
plt.show()

Common Derivatives with JAX

def g(x):
    # Uncomment to test different functions
    # return x**3
    # return 2*x**3 - 3*x**2 + 5
    # return 1/x
    # return jnp.exp(x)
    # return jnp.log(x)
    # return jnp.sin(x)
    # return jnp.cos(x)
    return jnp.abs(x)
    # return jnp.abs(x) + jnp.sin(x)*jnp.cos(x)

# JAX handles all these derivatives automatically
# (plot_f1_and_f2 is assumed to be a helper that plots g and its derivative together)
plot_f1_and_f2(g, vmap(grad(g)))
JAX successfully differentiates abs(x) and other functions where symbolic methods struggle.
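For instance, the derivative of the absolute value evaluates cleanly on either side of zero:
from jax import grad
import jax.numpy as jnp

print(grad(jnp.abs)(2.0))    # 1.0
print(grad(jnp.abs)(-2.0))   # -1.0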

Computational Efficiency Comparison

Compare execution times for different differentiation methods:
import time

x_array_large = np.linspace(-5, 5, 1000000)

# Symbolic differentiation
tic_symb = time.time()
res_symb = lambdify(x, diff(f(x), x), 'numpy')(x_array_large)
toc_symb = time.time()
time_symb = 1000 * (toc_symb - tic_symb)

# Numerical differentiation
tic_numerical = time.time()
res_numerical = np.gradient(f(x_array_large), x_array_large)
toc_numerical = time.time()
time_numerical = 1000 * (toc_numerical - tic_numerical)

# Automatic differentiation
tic_jax = time.time()
res_jax = vmap(grad(f))(jnp.array(x_array_large.astype('float32')))
toc_jax = time.time()
time_jax = 1000 * (toc_jax - tic_jax)

print(f"Time\nSymbolic: {time_symb} ms")
print(f"Numerical: {time_numerical} ms")
print(f"Automatic: {time_jax} ms")

Performance for Complex Functions

For polynomial compositions, automatic differentiation excels:
def f_polynomial_simple(x):
    return 2*x**3 - 3*x**2 + 5

def f_polynomial(x):
    for i in range(3):
        x = f_polynomial_simple(x)
    return x

# Compare symbolic vs automatic
tic_symb = time.time()
res_symb = lambdify(x, diff(f_polynomial(x), x), 'numpy')(x_array_large)
time_symb = 1000 * (time.time() - tic_symb)

tic_jax = time.time()
res_jax = vmap(grad(f_polynomial))(jnp.array(x_array_large.astype('float32')))
time_jax = 1000 * (time.time() - tic_jax)

print(f"Symbolic: {time_symb} ms")
print(f"Automatic: {time_jax} ms")
Automatic differentiation builds the computation graph once and applies the chain rule efficiently. As function complexity increases, autodiff’s advantage grows because it avoids expression swell (symbolic) and repeated function evaluation (numerical).
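If you run these timings yourself, note that JAX can also compile the derivative function with jax.jit, so the tracing cost is paid once and later calls reuse the compiled computation; a rough sketch building on the variables defined above:
from jax import jit

dfdx_polynomial = jit(vmap(grad(f_polynomial)))

x_large_jnp = jnp.array(x_array_large.astype('float32'))
_ = dfdx_polynomial(x_large_jnp).block_until_ready()    # first call traces and compiles
res = dfdx_polynomial(x_large_jnp).block_until_ready()  # later calls reuse the compiled graph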

Summary

| Method | Pros | Cons | Best Use Case |
| --- | --- | --- | --- |
| Symbolic | Exact results, algebraic manipulation | Expression swell, discontinuities | Simple functions, symbolic analysis |
| Numerical | Works with any function | Slow, approximate, discontinuities | Quick approximations, black-box functions |
| Automatic | Fast, exact, handles complexity | Requires compatible library | Machine learning, neural networks |
Recommendation: Use automatic differentiation (JAX) for machine learning applications. It combines the accuracy of symbolic methods with superior computational efficiency.
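As a small taste of what that looks like in practice, here is a minimal (illustrative) gradient-descent loop built on grad; the loss function is made up for the example:
from jax import grad

def loss(w):
    # Toy quadratic loss with its minimum at w = 4
    return (w - 4.0) ** 2

w = 0.0
learning_rate = 0.1
for step in range(100):
    w = w - learning_rate * grad(loss)(w)

print(w)  # approx 4.0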

Next Steps

Now that you understand these differentiation methods, the next step is applying them to optimization and machine learning tasks.
