Skip to main content
Metaflow’s @parallel decorator enables multi-node distributed computing for workloads that require coordination across multiple machines. This is essential for distributed training, large-scale simulations, and other parallel workloads.

Overview

The @parallel decorator creates a cluster of worker nodes that can communicate with each other:
from metaflow import FlowSpec, step, parallel, current

class DistributedFlow(FlowSpec):
    """Example flow: fan out one step across a 4-node @parallel cluster."""

    @step
    def start(self):
        # Fan out into a 4-node parallel cluster.
        self.next(self.train, num_parallel=4)

    @parallel
    @step
    def train(self):
        # Report this node's position within the cluster.
        print(f"Node {current.parallel.node_index} of {current.parallel.num_nodes}")
        print(f"Main node IP: {current.parallel.main_ip}")

        # Run this node's share of the distributed computation.
        # NOTE(review): train_on_node is assumed to be defined elsewhere.
        self.result = train_on_node(current.parallel.node_index)

        self.next(self.join)

    @step
    def join(self, inputs):
        # Gather the per-node results from every parallel task.
        gathered = [node.result for node in inputs]
        print(f"Completed with {len(gathered)} workers")
        self.next(self.end)

    @step
    def end(self):
        pass

How It Works

Cluster Topology

When you use @parallel, Metaflow creates a cluster:
  1. Control node (index 0) - The main coordinator
  2. Worker nodes (index 1, 2, …) - Additional compute nodes
All nodes run the same code but can coordinate through:
  • Shared network access via current.parallel.main_ip
  • Node identification via current.parallel.node_index
  • Cluster size via current.parallel.num_nodes
@parallel
@step
def distributed_step(self):
    """Branch on cluster role: index 0 is the control node, all others are workers."""
    if current.parallel.node_index == 0:
        print("I am the control node")
        # Coordinate the work
    else:
        print(f"I am worker {current.parallel.node_index}")
        # Do assigned work

Basic Usage

Simple Parallel Computation

from metaflow import FlowSpec, step, parallel, current

class SimpleParallel(FlowSpec):
    """Three local workers each compute the square of their node index."""

    @step
    def start(self):
        self.next(self.compute, num_parallel=3)

    @parallel
    @step
    def compute(self):
        # FIX: removed an unused `import time` that served no purpose here.
        # Each node does independent work based on its cluster index.
        my_index = current.parallel.node_index
        result = my_index ** 2

        print(f"Node {my_index} computed {result}")
        self.result = result

        self.next(self.join)

    @step
    def join(self, inputs):
        # Collect the per-node artifacts produced above.
        results = [inp.result for inp in inputs]
        print(f"Results: {results}")  # [0, 1, 4]
        self.next(self.end)

    @step
    def end(self):
        pass

With Remote Compute

Combine @parallel with @batch or @kubernetes:
from metaflow import FlowSpec, step, parallel, kubernetes, resources

class CloudParallel(FlowSpec):
    """Run a 4-node @parallel step on Kubernetes with per-node resources."""

    @step
    def start(self):
        self.next(self.train, num_parallel=4)
    
    @parallel
    @kubernetes
    @resources(cpu=8, memory=32000, gpu=1)
    @step
    def train(self):
        # Each node: 8 CPUs, 32GB RAM, 1 GPU
        # Total: 4 nodes = 32 CPUs, 128GB RAM, 4 GPUs
        # NOTE(review): train_distributed is assumed to be defined elsewhere.
        train_distributed()
        self.next(self.join)
    
    @step
    def join(self, inputs):
        self.next(self.end)
    
    @step
    def end(self):
        pass

The current.parallel Object

Metaflow provides information about the parallel cluster:
from metaflow import current

@parallel
@step
def my_step(self):
    """Inspect the cluster metadata exposed on current.parallel inside a @parallel step."""
    # Main node IP address for communication
    main_ip = current.parallel.main_ip  # e.g., "10.0.1.5"
    
    # Total number of nodes in cluster
    num_nodes = current.parallel.num_nodes  # e.g., 4
    
    # This node's index (0 is control node)
    node_index = current.parallel.node_index  # e.g., 0, 1, 2, or 3
    
    # Control task ID for coordination
    control_id = current.parallel.control_task_id  # e.g., "123"
    
    print(f"Node {node_index}/{num_nodes} at {main_ip}")
Access from outside the step:
@step
def regular_step(self):
    """Report whether the currently running step is a @parallel step."""
    if not current.is_parallel:
        print("This is a regular step")
    else:
        print("This is a parallel step")

Distributed Training Patterns

PyTorch Distributed

import torch
import torch.distributed as dist
from metaflow import FlowSpec, step, parallel, current, kubernetes

class PyTorchDistributed(FlowSpec):
    """Multi-node PyTorch DDP training driven by Metaflow's @parallel."""

    @step
    def start(self):
        self.next(self.train, num_parallel=4)

    @parallel
    @kubernetes(gpu=1, memory=32000)
    @step
    def train(self):
        # Initialize the default process group; the control node (index 0)
        # hosts the TCP rendezvous endpoint.
        dist.init_process_group(
            backend='nccl',
            init_method=f'tcp://{current.parallel.main_ip}:29500',
            rank=current.parallel.node_index,
            world_size=current.parallel.num_nodes
        )

        # Create model and wrap with DDP.
        # NOTE(review): MyModel, train_dataset, and train_step are assumed
        # to be defined elsewhere -- confirm in the real flow.
        model = MyModel().cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

        # FIX: the original example used `optimizer` without ever creating it.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # Shard the dataset so each rank sees a disjoint slice.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=current.parallel.num_nodes,
            rank=current.parallel.node_index
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=32,
            sampler=train_sampler
        )

        # Training loop
        for epoch in range(10):
            # Re-seed the sampler so each epoch shuffles differently.
            train_sampler.set_epoch(epoch)
            for batch in train_loader:
                optimizer.zero_grad()  # FIX: gradients must be cleared each step
                loss = train_step(model, batch)
                loss.backward()
                optimizer.step()

        # Save model from rank 0 only, to avoid concurrent writes.
        if current.parallel.node_index == 0:
            torch.save(model.state_dict(), 'model.pt')
            self.model_path = 'model.pt'
        else:
            # FIX: every task must set the artifact so join can read it.
            self.model_path = None

        dist.destroy_process_group()  # release NCCL resources cleanly

        self.next(self.join)

    @step
    def join(self, inputs):
        # FIX: the saving task is not guaranteed to be inputs[0]; pick
        # whichever task actually recorded a model path.
        self.model_path = next(inp.model_path for inp in inputs if inp.model_path)
        self.next(self.end)

    @step
    def end(self):
        print(f"Model saved to {self.model_path}")

TensorFlow MultiWorkerMirroredStrategy

import tensorflow as tf
from metaflow import FlowSpec, step, parallel, current, batch
import json

class TFDistributed(FlowSpec):
    """Multi-node TensorFlow training via MultiWorkerMirroredStrategy."""

    @step
    def start(self):
        self.next(self.train, num_parallel=4)
    
    @parallel
    @batch(gpu=1, memory=32000)
    @step
    def train(self):
        # Configure TF_CONFIG for MultiWorkerMirroredStrategy
        # NOTE(review): with num_parallel=4 the 'worker' list must contain
        # the address of EVERY worker, not just the main node -- as written
        # this describes a single-worker cluster. Confirm how the worker
        # addresses are discovered before using this in production.
        tf_config = {
            'cluster': {
                'worker': [
                    f'{current.parallel.main_ip}:12345',
                    # Add more workers if needed
                ]
            },
            'task': {
                'type': 'worker',
                'index': current.parallel.node_index
            }
        }
        
        import os
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
        
        # Create strategy
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        
        # Build and train model
        # NOTE(review): create_model and train_dataset are assumed to be
        # defined elsewhere.
        with strategy.scope():
            model = create_model()
            model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
        
        # Train
        model.fit(train_dataset, epochs=10)
        
        # Save from first worker
        if current.parallel.node_index == 0:
            model.save('trained_model')
            self.model_path = 'trained_model'
        
        self.next(self.join)
    
    @step
    def join(self, inputs):
        # NOTE(review): only the index-0 task sets model_path, and inputs[0]
        # is assumed to be that task -- verify the ordering guarantee.
        self.model_path = inputs[0].model_path
        self.next(self.end)
    
    @step
    def end(self):
        print(f"Training complete: {self.model_path}")

Horovod

import horovod.torch as hvd
from metaflow import FlowSpec, step, parallel, current, kubernetes

class HorovodTraining(FlowSpec):
    """8-node Horovod data-parallel training on Kubernetes."""

    @step
    def start(self):
        self.next(self.train, num_parallel=8)

    @parallel
    @kubernetes(gpu=1, cpu=8, memory=32000)
    @step
    def train(self):
        # Initialize Horovod
        hvd.init()

        # Pin this process to the GPU matching its local rank.
        import torch
        torch.cuda.set_device(hvd.local_rank())

        # Create model
        # NOTE(review): MyModel, train_loader, and train_step are assumed
        # to be defined elsewhere -- confirm in the real flow.
        model = MyModel().cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # Wrap optimizer with Horovod so gradients are allreduced.
        optimizer = hvd.DistributedOptimizer(
            optimizer,
            named_parameters=model.named_parameters()
        )

        # Ensure every rank starts from identical weights and optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        # Training loop
        for epoch in range(10):
            for batch in train_loader:
                optimizer.zero_grad()  # FIX: gradients must be cleared each step
                loss = train_step(model, batch)
                loss.backward()
                optimizer.step()

        # Save from rank 0 only, to avoid concurrent writes.
        if hvd.rank() == 0:
            torch.save(model.state_dict(), 'model.pt')
            self.model_path = 'model.pt'
        else:
            # FIX: every task must set the artifact so join can read it.
            self.model_path = None

        self.next(self.join)

    @step
    def join(self, inputs):
        # FIX: the saving task is not guaranteed to be inputs[0]; pick
        # whichever task actually recorded a model path.
        self.model_path = next(inp.model_path for inp in inputs if inp.model_path)
        self.next(self.end)

    @step
    def end(self):
        print(f"Horovod training complete: {self.model_path}")

Communication Between Nodes

Using Sockets

import socket

@parallel
@step
def communicate(self):
    """Gather one greeting from every worker on the control node via TCP."""
    if current.parallel.node_index == 0:
        # Control node: accept exactly one connection per worker.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            # FIX: allow quick rebinding if the step is retried.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(('0.0.0.0', 9999))
            sock.listen(current.parallel.num_nodes - 1)

            messages = []
            for _ in range(current.parallel.num_nodes - 1):
                conn, addr = sock.accept()
                # FIX: `with` guarantees the connection closes even on error.
                with conn:
                    data = conn.recv(1024).decode()
                messages.append(data)

        print(f"Received: {messages}")
        self.messages = messages
    else:
        # Workers: retry until the control node is listening.
        # FIX: a fixed time.sleep(2) was a race -- the control node may take
        # longer than 2 seconds to start accepting connections.
        import time

        for _ in range(30):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((current.parallel.main_ip, 9999))
                break
            except OSError:
                sock.close()
                time.sleep(1)
        else:
            raise RuntimeError("control node never became reachable on port 9999")

        with sock:
            # FIX: sendall guarantees the whole payload is written
            # (socket.send may write only part of the buffer).
            sock.sendall(f"Hello from worker {current.parallel.node_index}".encode())

Using Redis for Coordination

import redis

@parallel
@step
def coordinate_with_redis(self):
    """Barrier-synchronize workers and collect results through a shared Redis."""
    # Connect to Redis (running separately, assumed reachable on the main node).
    r = redis.Redis(host=current.parallel.main_ip, port=6379)

    # Each worker announces readiness.
    r.set(f'worker:{current.parallel.node_index}', 'ready')

    # Barrier: wait until every node has checked in.
    import time
    while len(r.keys('worker:*')) < current.parallel.num_nodes:
        time.sleep(1)

    print("All workers ready!")

    # Do coordinated work
    # NOTE(review): compute_work is assumed to be defined elsewhere and to
    # return a value Redis can store (str/bytes/number) -- confirm.
    result = compute_work(current.parallel.node_index)
    r.set(f'result:{current.parallel.node_index}', result)

    # Control node collects results
    if current.parallel.node_index == 0:
        # FIX: wait for every result key to exist -- previously the control
        # node could read results before slower workers had written them.
        while len(r.keys('result:*')) < current.parallel.num_nodes:
            time.sleep(1)
        # FIX: redis-py returns bytes; decode for usable string artifacts.
        results = [
            r.get(f'result:{i}').decode()
            for i in range(current.parallel.num_nodes)
        ]
        self.all_results = results

Platform-Specific Configuration

AWS Batch Multi-Node

from metaflow import FlowSpec, step, parallel, batch, resources

class BatchMultiNode(FlowSpec):
    """4-node @parallel step on AWS Batch multi-node parallel jobs."""

    @step
    def start(self):
        self.next(self.train, num_parallel=4)
    
    @parallel
    @batch(efa=1)  # Elastic Fabric Adapter for low latency
    @resources(cpu=16, memory=64000)
    @step
    def train(self):
        # AWS Batch sets these automatically:
        # AWS_BATCH_JOB_NUM_NODES
        # AWS_BATCH_JOB_NODE_INDEX
        # AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS
        
        # NOTE(review): distributed_training is assumed to be defined elsewhere.
        print(f"Running on AWS Batch node {current.parallel.node_index}")
        distributed_training()
        self.next(self.join)
    
    @step
    def join(self, inputs):
        self.next(self.end)
    
    @step
    def end(self):
        pass

Kubernetes JobSets

from metaflow import FlowSpec, step, parallel, kubernetes, resources

class K8sMultiNode(FlowSpec):
    """4-node @parallel step on Kubernetes (backed by a JobSet)."""

    @step
    def start(self):
        self.next(self.train, num_parallel=4)
    
    @parallel
    @kubernetes
    @resources(cpu=8, memory=32000, gpu=1)
    @step
    def train(self):
        # Kubernetes creates JobSet with control and worker pods
        # Environment variables set:
        # MF_MASTER_ADDR - control pod hostname
        # MF_WORLD_SIZE - total number of pods
        # MF_WORKER_REPLICA_INDEX - worker index (or MF_CONTROL_INDEX for control)
        
        # NOTE(review): distributed_training is assumed to be defined elsewhere.
        print(f"K8s node {current.parallel.node_index}")
        distributed_training()
        self.next(self.join)
    
    @step
    def join(self, inputs):
        self.next(self.end)
    
    @step
    def end(self):
        pass

Local Testing

Test multi-node code locally:
class TestParallel(FlowSpec):
    """Two-node @parallel flow for local smoke testing."""

    @step
    def start(self):
        self.next(self.compute, num_parallel=2)

    @parallel
    @step
    def compute(self):
        # Locally this runs as multiple cooperating processes.
        idx = current.parallel.node_index
        print(f"Local node {idx}")
        self.result = idx * 2
        self.next(self.join)

    @step
    def join(self, inputs):
        collected = [task.result for task in inputs]
        print(f"Results: {collected}")  # [0, 2]
        self.next(self.end)

    @step
    def end(self):
        pass
Run locally:
python myflow.py run
# Creates multiple local processes that communicate via 127.0.0.1

Best Practices

Only the control node (index 0) should save final outputs to avoid race conditions:
@parallel
@step
def train(self):
    """Train on every node, then persist the model from the control node only."""
    # NOTE(review): train_distributed and save_model are assumed to be
    # defined elsewhere.
    train_distributed()
    
    if current.parallel.node_index == 0:
        self.model = save_model()  # Only control node saves
Choose the right backend for your framework:
  • PyTorch: nccl for GPU, gloo for CPU
  • TensorFlow: MultiWorkerMirroredStrategy
  • Horovod: Automatic backend selection
Implement checkpointing and recovery:
@parallel
@step
def train(self):
    """Resumable distributed training with periodic checkpoints."""
    # FIX: start_epoch was undefined when no checkpoint existed, which would
    # raise NameError on a fresh run.
    if checkpoint_exists():
        # NOTE(review): assumes load_checkpoint returns the epoch to resume
        # from -- confirm its contract.
        start_epoch = load_checkpoint()
    else:
        start_epoch = 0

    for epoch in range(start_epoch, num_epochs):
        train_epoch()
        # Only the control node writes checkpoints to avoid clobbering.
        if current.parallel.node_index == 0:
            save_checkpoint(epoch)
More nodes isn’t always better. Consider:
  • Communication overhead increases with cluster size
  • Batch size per worker decreases with more workers
  • Diminishing returns beyond certain cluster sizes
  • Start small and scale based on metrics

Common Issues

Nodes cannot communicate. Symptoms: timeout errors, connection refused. Solutions:
  • Verify security groups allow inter-node communication
  • Check firewall rules in Kubernetes
  • Ensure correct port numbers
  • Verify current.parallel.main_ip is accessible
Uneven workload. Symptoms: some nodes finish much earlier than others. Solutions:
  • Use distributed samplers to balance data
  • Implement dynamic work distribution
  • Profile per-node utilization
  • Adjust batch sizes
Out of memory. Symptoms: OOM errors during multi-node training. Solutions:
  • Reduce per-node batch size
  • Enable gradient checkpointing
  • Use mixed precision training
  • Increase memory allocation
Inconsistent results. Symptoms: non-deterministic training results. Solutions:
  • Set random seeds on all nodes
  • Synchronize initial model weights
  • Use deterministic algorithms
  • Verify data shuffling is consistent

Performance Optimization

Network Bandwidth

Optimize communication:
@parallel
@batch(efa=1)  # Use EFA for high bandwidth
@step
def train(self):
    """DDP training with fp16 gradient compression to reduce network traffic."""
    # FIX: torch itself was never imported in the original snippet even
    # though torch.nn.parallel is used below.
    import torch
    import torch.distributed as dist

    dist.init_process_group(
        backend='nccl',
        init_method=f'tcp://{current.parallel.main_ip}:29500',
        rank=current.parallel.node_index,
        world_size=current.parallel.num_nodes
    )

    # FIX: `model` was used before being defined in the original snippet.
    # NOTE(review): MyModel is assumed to be defined elsewhere.
    model = MyModel().cuda()

    # Enable gradient compression: allreduce gradients in fp16 to halve
    # the bytes sent over the wire.
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
    model = torch.nn.parallel.DistributedDataParallel(model)
    model.register_comm_hook(
        state=None,
        hook=default_hooks.fp16_compress_hook
    )

Overlapping Computation and Communication

# Use bucketing to overlap gradient computation with communication
# NOTE(review): `model` must already be constructed and moved to the target
# device before this wrap -- not shown in this fragment.
model = torch.nn.parallel.DistributedDataParallel(
    model,
    bucket_cap_mb=25,  # Smaller buckets = more overlap
    find_unused_parameters=True  # NOTE(review): adds per-iteration graph traversal overhead; enable only if some parameters genuinely receive no gradient
)

Next Steps

Resources Management

Configure CPU, memory, and GPU for distributed jobs

AWS Batch

Run distributed jobs on AWS Batch

Kubernetes

Deploy distributed workloads on Kubernetes

Remote Execution

Understand remote execution fundamentals

Build docs developers (and LLMs) love