MLX supports executing operations on different devices (CPU, GPU) and coordinating execution through streams. Streams allow for asynchronous execution and fine-grained control over when operations run.
Devices
Device
Represents a computation device in MLX.
import mlx.core as mx
# Create devices
cpu = mx.Device(mx.cpu)
gpu = mx.Device(mx.gpu)
print(cpu) # Device(cpu, 0)
print(gpu) # Device(gpu, 0)
Devices have two components:
- Device type: either cpu or gpu
- Device index: an integer identifying which device (useful for multi-GPU systems)
Common device types:
import mlx.core as mx
mx.cpu # CPU device type
mx.gpu # GPU device type
default_device
Get the default device for operations.
import mlx.core as mx
device = mx.default_device()
print(device) # Device(gpu, 0) or Device(cpu, 0)
Returns: The current default device.
set_default_device
Set the default device for operations.
import mlx.core as mx
# Set CPU as default
mx.set_default_device(mx.cpu)
# Now arrays are created on CPU by default
a = mx.array([1, 2, 3])
# Set GPU as default
mx.set_default_device(mx.gpu)
# Now arrays are created on GPU by default
b = mx.array([4, 5, 6])
Parameter: The device to set as default.
device_count
Get the number of devices available.
import mlx.core as mx
num_gpus = mx.device_count(mx.gpu)
print(f"Number of GPUs: {num_gpus}")
num_cpus = mx.device_count(mx.cpu)
print(f"Number of CPUs: {num_cpus}") # Usually 1
Parameter: The type of device to count (e.g., mx.cpu or mx.gpu).
Returns: The number of available devices of the specified type.
device_info
Get information about a device.
import mlx.core as mx
info = mx.device_info(mx.gpu)
print(info)
# Returns device-specific information like memory, architecture, etc.
Returns: A dictionary containing device information.
Streams
Stream
Represents an execution stream for scheduling operations.
import mlx.core as mx
# Create a new stream on the GPU
stream = mx.Stream(mx.gpu)
# Schedule operations on this stream
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream)
Streams allow operations to execute asynchronously. Operations on different streams can run in parallel, while operations on the same stream execute sequentially.
Stream properties:
- device: The device associated with the stream.
default_stream
Get the default stream for a device.
import mlx.core as mx
# Get default stream for GPU
stream = mx.default_stream(mx.gpu)
# Get default stream for CPU
cpu_stream = mx.default_stream(mx.cpu)
Parameter: The device to get the default stream for.
Returns: The default stream for the specified device.
set_default_stream
Set the default stream for a device.
import mlx.core as mx
# Create a custom stream
custom_stream = mx.Stream(mx.gpu)
# Set it as the default for GPU operations
mx.set_default_stream(custom_stream)
# Now operations use this stream by default
a = mx.array([1, 2, 3])
Parameter: The stream to set as default.
new_stream
Create a new stream on a device.
import mlx.core as mx
# Create new stream on GPU
stream1 = mx.new_stream(mx.gpu)
stream2 = mx.new_stream(mx.gpu)
# Operations on different streams can run in parallel
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream1)
c = mx.multiply(a, 2, stream=stream2)
Parameter: The device to create the stream on.
Returns: A new stream for the specified device.
stream
Context manager for temporarily setting the default stream.
import mlx.core as mx
# Create a custom stream
custom_stream = mx.new_stream(mx.gpu)
# Use it as default within a context
with mx.stream(custom_stream):
    # Operations in this block use custom_stream
    a = mx.array([1, 2, 3])
    b = mx.add(a, 1)
# Outside the context, the original default stream is restored
c = mx.array([4, 5, 6])
Parameter: The stream to use as default within the context.
synchronize
Wait for operations on a stream to complete.
import mlx.core as mx
# Schedule operations
a = mx.array([1, 2, 3])
b = mx.add(a, 1)
# Wait for all operations on default stream to complete
mx.synchronize()
# Now it's safe to access the result
print(b) # array([2, 3, 4], dtype=int32)
Parameter: The stream to synchronize. If not provided, the default stream is synchronized.
Using Streams for Parallelism
Streams enable parallelism by allowing operations to overlap:
import mlx.core as mx
# Create two streams
stream1 = mx.new_stream(mx.gpu)
stream2 = mx.new_stream(mx.gpu)
# These operations can run in parallel
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream1) # Runs on stream1
c = mx.multiply(a, 2, stream=stream2) # Runs on stream2 (in parallel)
# Wait for both to complete
mx.synchronize(stream1)
mx.synchronize(stream2)
print(b) # array([2, 3, 4], dtype=int32)
print(c) # array([2, 4, 6], dtype=int32)
Device-Specific Operations
You can explicitly specify which device an operation should run on:
import mlx.core as mx
# Create array on CPU
a = mx.array([1, 2, 3])
# Perform operation on GPU
b = mx.add(a, 1, stream=mx.default_stream(mx.gpu))
# Perform operation on CPU
c = mx.multiply(a, 2, stream=mx.default_stream(mx.cpu))
Best Practices
Lazy Evaluation
MLX uses lazy evaluation: operations are not executed immediately but are scheduled on a stream. To ensure an operation has completed before accessing its result:
import mlx.core as mx
a = mx.array([1, 2, 3])
b = mx.add(a, 1)
# Force evaluation
mx.eval(b) # or mx.synchronize()
# Now safe to use b
print(b)
Multi-GPU Usage
For systems with multiple GPUs:
import mlx.core as mx
num_gpus = mx.device_count(mx.gpu)
print(f"Available GPUs: {num_gpus}")
# Use specific GPU
if num_gpus > 1:
    gpu0 = mx.Device(mx.gpu, 0)
    gpu1 = mx.Device(mx.gpu, 1)
    stream0 = mx.default_stream(gpu0)
    stream1 = mx.default_stream(gpu1)
    # Schedule operations on different GPUs
    a = mx.array([1, 2, 3])
    b = mx.add(a, 1, stream=stream0)  # GPU 0
    c = mx.add(a, 2, stream=stream1)  # GPU 1
Stream Safety
Operations on the same stream execute in order:
import mlx.core as mx
stream = mx.new_stream(mx.gpu)
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream)
c = mx.multiply(b, 2, stream=stream) # Waits for b to complete
mx.synchronize(stream)
print(c) # array([4, 6, 8], dtype=int32)