Skip to main content

Overview

The foreach pattern allows you to dynamically fan out and process items in parallel. Unlike static branching where you specify branch names, foreach creates a separate task for each item in a collection.
from metaflow import FlowSpec, step

class ForeachFlow(FlowSpec):
    @step
    def start(self):
        """Fan out dynamically: one parallel task per genre."""
        self.genres = ['Action', 'Comedy', 'Drama', 'Sci-Fi']
        # foreach='genres' names the artifact to iterate over; Metaflow
        # spawns one process_genre task per element.
        self.next(self.process_genre, foreach='genres')

    @step
    def process_genre(self):
        """Runs once per genre; self.input holds the current element."""
        print(f"Processing: {self.input}")
        self.result = f"Processed {self.input}"
        self.next(self.join)

    @step
    def join(self, inputs):
        """Gather the per-genre results back into a single artifact."""
        self.all_results = []
        for branch in inputs:
            self.all_results.append(branch.result)
        print(f"Processed {len(inputs)} genres")
        self.next(self.end)

    @step
    def end(self):
        """Terminal step; nothing left to do."""
        pass

Real-World Example: Movie Statistics

From the Metaflow tutorials:
class MovieStatsFlow(FlowSpec):
    @step
    def start(self):
        """Load the raw CSV and fan out over the unique genres.

        NOTE(review): assumes *self.movie_data* (the raw CSV text) is
        supplied elsewhere, e.g. via an IncludeFile parameter — confirm.
        """
        import csv

        columns = ['movie_title', 'title_year', 'genres', 'gross']
        self.dataframe = {col: [] for col in columns}

        for row in csv.DictReader(self.movie_data.splitlines()):
            for col in columns:
                # Numeric columns are stored as ints; everything else as str.
                val = int(row[col]) if col in ('title_year', 'gross') else row[col]
                self.dataframe[col].append(val)

        # Genres come pipe-delimited per movie; collect the unique set.
        self.genres = {
            genre
            for genres in self.dataframe['genres']
            for genre in genres.split('|')
        }
        self.genres = list(self.genres)

        # Compute statistics for each genre in parallel
        self.next(self.compute_statistics, foreach='genres')

    @step
    def compute_statistics(self):
        """Filter the table down to one genre and compute gross quartiles."""
        # The current genre being processed
        self.genre = self.input
        print(f"Computing statistics for {self.genre}")

        # Row mask: True where this movie lists the current genre.
        selector = [self.genre in row for row in self.dataframe['genres']]

        for col in self.dataframe.keys():
            # Fix: the original comprehension reused *col* as its loop
            # variable, shadowing the dict key — renamed to *val*.
            self.dataframe[col] = [
                val for val, is_genre in zip(self.dataframe[col], selector)
                if is_genre
            ]

        # Fix: quartiles must be taken from a *sorted* sequence (the original
        # indexed the unsorted column).  Also clamp the index — for small n,
        # round(n * 0.75) can equal n and run past the end of the list.
        gross_sorted = sorted(self.dataframe['gross'])
        n_points = len(gross_sorted)
        self.quartiles = []
        for cut in [0.25, 0.5, 0.75]:
            idx = 0 if n_points < 2 else min(round(n_points * cut), n_points - 1)
            self.quartiles.append(gross_sorted[idx])

        self.next(self.join)

    @step
    def join(self, inputs):
        """Merge the per-genre branches into one stats dictionary."""
        self.genre_stats = {
            inp.genre.lower(): {
                'quartiles': inp.quartiles,
                'dataframe': inp.dataframe
            }
            for inp in inputs
        }
        self.next(self.end)

    @step
    def end(self):
        """Terminal step — required for a valid Metaflow graph; the original
        example omitted it even though join() transitions here."""
        pass

The Foreach Pattern

Syntax

self.next(self.step_name, foreach='attribute_name')
  • step_name: The step to execute for each item
  • foreach: The name of an attribute containing an iterable

Implementation Details

From the source code:
# From flowspec.py:1090
try:
    foreach_iter = getattr(self, foreach)
except:
    msg = (
        "Foreach variable *self.{var}* in step *{step}* "
        "does not exist. Check your variable.".format(
            step=step, var=foreach
        )
    )
    raise InvalidNextException(msg)

# Calculate number of splits
self._foreach_num_splits = sum(1 for _ in foreach_iter)

if self._foreach_num_splits == 0:
    msg = (
        "Foreach iterator over *{var}* in step *{step}* "
        "produced zero splits. Check your variable.".format(
            step=step, var=foreach
        )
    )
    raise InvalidNextException(msg)

Accessing Foreach Data

self.input

In a foreach step, self.input contains the current item:
@step
def start(self):
    self.numbers = [1, 2, 3, 4, 5]
    self.next(self.square, foreach='numbers')

@step
def square(self):
    # self.input is 1, then 2, then 3, etc.
    self.result = self.input ** 2
    self.next(self.join)

self.index

The zero-based index of the current item:
@step
def square(self):
    print(f"Processing item {self.index}: {self.input}")
    # Output:
    # Processing item 0: 1
    # Processing item 1: 2
    # Processing item 2: 3
    # ...
    self.result = self.input ** 2
    self.next(self.join)
From the source code:
# From flowspec.py:607
@property
def index(self) -> Optional[int]:
    """The index of this foreach branch."""
    if self._foreach_stack:
        return self._foreach_stack[-1].index

Foreach Stack

For nested foreach loops, use foreach_stack():
@step
def start(self):
    self.groups = ['A', 'B', 'C']
    self.next(self.process_group, foreach='groups')

@step
def process_group(self):
    self.items = [1, 2, 3, 4]
    self.next(self.process_item, foreach='items')

@step
def process_item(self):
    # foreach_stack shows the full nesting
    stack = self.foreach_stack()
    # stack = [(0, 3, 'A'), (0, 4, 1)] for first item of first group
    # Each tuple: (index, num_splits, value)
    
    group_idx, group_count, group_val = stack[0]
    item_idx, item_count, item_val = stack[1]
    
    print(f"Group {group_val}[{group_idx}], Item {item_val}[{item_idx}]")
    self.next(self.join)
From the documentation in flowspec.py:
def foreach_stack(self) -> Optional[List[Tuple[int, int, Any]]]:
    """
    Returns the current stack of foreach indexes and values.
    
    Returns a list where each tuple contains:
    - The index of the task for that level of the loop
    - The number of splits for that level of the loop  
    - The value for that level of the loop
    """
    return [
        (frame.index, frame.num_splits, self._find_input(stack_index=i))
        for i, frame in enumerate(self._foreach_stack)
    ]

Nested Foreach

You can nest foreach loops for multi-dimensional processing:
class NestedForeachFlow(FlowSpec):
    @step
    def start(self):
        """Outer fan-out: one branch per model type."""
        self.models = ['linear', 'tree', 'neural']
        self.next(self.train_model, foreach='models')

    @step
    def train_model(self):
        """Inner fan-out: one branch per hyper-parameter set for this model."""
        self.model_name = self.input
        self.param_sets = [
            {'lr': 0.01},
            {'lr': 0.001},
            {'lr': 0.0001}
        ]
        self.next(self.tune_params, foreach='param_sets')

    @step
    def tune_params(self):
        """Train one (model, params) combination.

        NOTE(review): *train* is assumed to be a helper defined/imported by
        the real flow — confirm before running.
        """
        params = self.input
        self.score = train(self.model_name, params)
        self.next(self.join_params)

    @step
    def join_params(self, inputs):
        # Join the inner loop: pick the best params for this model.
        best = max(inputs, key=lambda x: x.score)
        self.best_params = best.input
        self.best_score = best.score
        self.next(self.join_models)

    @step
    def join_models(self, inputs):
        # Join the outer loop: collect the best result per model.
        self.results = {
            inp.model_name: {
                'params': inp.best_params,
                'score': inp.best_score
            }
            for inp in inputs
        }
        self.next(self.end)

    @step
    def end(self):
        """Terminal step — required for a valid Metaflow graph; the original
        example omitted it even though join_models() transitions here."""
        pass

Foreach vs Static Branching

| Feature | Static Branching | Foreach |
| --- | --- | --- |
| Number of branches | Fixed at code time | Dynamic at runtime |
| Branch specification | Explicit: `self.next(self.a, self.b)` | Implicit: `foreach='items'` |
| Branch naming | Named methods | Same method, different data |
| Use case | Different operations | Same operation on different data |

Iterable Types

Foreach works with any Python iterable:
# Lists
self.items = [1, 2, 3]
self.next(self.process, foreach='items')

# Tuples
self.items = ('a', 'b', 'c')
self.next(self.process, foreach='items')

# Dictionaries (iterates over keys)
self.items = {'x': 1, 'y': 2}
self.next(self.process, foreach='items')  # self.input will be 'x', then 'y'

# Ranges
self.items = range(10)
self.next(self.process, foreach='items')

# Generators (must be materialized)
self.items = list(x for x in range(10) if x % 2 == 0)
self.next(self.process, foreach='items')

Complex Items

Foreach items can be complex objects:
@step
def start(self):
    self.configs = [
        {'model': 'rf', 'params': {'n_trees': 100}},
        {'model': 'xgb', 'params': {'max_depth': 5}},
        {'model': 'nn', 'params': {'layers': [64, 32]}}
    ]
    self.next(self.train, foreach='configs')

@step
def train(self):
    config = self.input
    model_type = config['model']
    params = config['params']
    self.result = train_model(model_type, params)
    self.next(self.join)

Join Operations

Collecting Results

@step
def join(self, inputs):
    # List of all results
    self.results = [inp.result for inp in inputs]
    
    # Dictionary indexed by input value
    self.results_dict = {
        inp.input: inp.result 
        for inp in inputs
    }
    
    # Find best result
    best = max(inputs, key=lambda x: x.score)
    self.best_result = best.result
    
    self.next(self.end)

Aggregation

@step
def join(self, inputs):
    # Sum
    self.total = sum(inp.value for inp in inputs)
    
    # Average
    self.average = sum(inp.value for inp in inputs) / len(inputs)
    
    # Merge lists
    self.all_items = []
    for inp in inputs:
        self.all_items.extend(inp.items)
    
    self.next(self.end)

Error Handling

By default, if one foreach task fails, the entire flow fails. Use @catch for graceful handling:
from metaflow import catch

@catch(var='error')
@step
def process_item(self):
    """Process item, but catch failures."""
    if self.input == 'bad':
        raise ValueError("Bad item!")
    self.result = process(self.input)
    self.next(self.join)

@step
def join(self, inputs):
    # Filter out failed tasks
    successful = [inp for inp in inputs if inp.error is None]
    failed = [inp for inp in inputs if inp.error is not None]
    
    print(f"Success: {len(successful)}, Failed: {len(failed)}")
    self.results = [inp.result for inp in successful]
    self.next(self.end)

Performance Considerations

Number of Tasks

# Small number of tasks: OK
self.items = range(10)
self.next(self.process, foreach='items')  # 10 parallel tasks

# Large number of tasks: Consider batching
self.items = range(10000)
self.next(self.process, foreach='items')  # 10,000 tasks - may be too many!

Batching Strategy

@step
def start(self):
    items = range(10000)
    batch_size = 100
    
    # Create batches
    self.batches = [
        list(items[i:i+batch_size])
        for i in range(0, len(items), batch_size)
    ]
    # 100 tasks instead of 10,000
    self.next(self.process_batch, foreach='batches')

@step
def process_batch(self):
    batch = self.input
    self.results = [process_item(item) for item in batch]
    self.next(self.join)

Task Size

Balance between too many small tasks and too few large tasks:
# Too granular: Overhead dominates
self.items = [1, 2, 3, ...]  # 1000 tiny items

# Too coarse: No parallelism benefit
self.items = [all_data]  # 1 huge batch

# Good balance: Reasonable number of medium-sized tasks
self.items = [batch1, batch2, ...]  # 10-100 batches

Num Parallel

For simple parallel execution without data:
@step
def start(self):
    # Run 4 parallel instances
    self.next(self.train_model, num_parallel=4)

@step
def train_model(self):
    # self.input is the task index: 0, 1, 2, 3
    # self.index is also 0, 1, 2, 3
    seed = self.index
    self.result = train_with_random_seed(seed)
    self.next(self.join)
Implementation:
# From flowspec.py:1066
if num_parallel is not None and num_parallel >= 1:
    foreach = '_parallel_ubf_iter'
    self._parallel_ubf_iter = ParallelUBF(num_parallel)

Best Practices

1. Keep Foreach Steps Simple

# Good: One clear responsibility
@step
def process_item(self):
    self.result = transform(self.input)
    self.next(self.join)

# Avoid: Multiple operations
@step
def process_item(self):
    self.result1 = transform1(self.input)
    self.result2 = transform2(self.result1)
    self.result3 = transform3(self.result2)
    self.next(self.join)

2. Validate Before Foreach

@step
def start(self):
    self.items = load_items()
    
    # Validate before creating tasks
    if not self.items:
        raise ValueError("No items to process!")
    
    if len(self.items) > 1000:
        print("Warning: Creating many tasks")
    
    self.next(self.process, foreach='items')

3. Include Identifiers

@step
def process_item(self):
    # Store identifier for tracking
    self.item_id = self.input.get('id')
    self.result = process(self.input)
    self.next(self.join)

@step
def join(self, inputs):
    # Easy to map results back to items
    self.results = {
        inp.item_id: inp.result 
        for inp in inputs
    }
    self.next(self.end)

4. Use Appropriate Resources

from metaflow import resources

@resources(memory=4000, cpu=2)
@step
def process_item(self):
    # Each parallel task gets 4GB RAM, 2 CPUs
    self.result = expensive_operation(self.input)
    self.next(self.join)

Common Patterns

Map-Reduce

@step
def map_step(self):
    self.data = load_large_dataset()
    self.chunks = split_into_chunks(self.data, n=10)
    self.next(self.process_chunk, foreach='chunks')

@step
def process_chunk(self):
    # Map: Process each chunk
    self.partial_result = process(self.input)
    self.next(self.reduce)

@step
def reduce(self, inputs):
    # Reduce: Combine results
    self.final_result = combine([inp.partial_result for inp in inputs])
    self.next(self.end)
Grid Search

@step
def start(self):
    import itertools
    
    # Generate parameter grid
    param_grid = {
        'learning_rate': [0.01, 0.001, 0.0001],
        'batch_size': [32, 64, 128],
        'layers': [[64], [64, 32], [128, 64, 32]]
    }
    
    keys = param_grid.keys()
    values = param_grid.values()
    self.param_combinations = [
        dict(zip(keys, v)) 
        for v in itertools.product(*values)
    ]
    
    self.next(self.train, foreach='param_combinations')

@step
def train(self):
    params = self.input
    self.score = train_and_evaluate(params)
    self.next(self.select_best)

@step
def select_best(self, inputs):
    best = max(inputs, key=lambda x: x.score)
    self.best_params = best.input
    self.best_score = best.score
    self.next(self.end)

Next Steps

Error Handling

Handle failures in foreach tasks with @catch

Data Management

Manage data artifacts efficiently in parallel tasks

Branching

Learn about static branching patterns

Parameters

Use Parameters to configure foreach iterations

Build docs developers (and LLMs) love