Overview
The foreach pattern allows you to dynamically fan out and process items in parallel. Unlike static branching, where you specify branch names explicitly, foreach creates a separate task for each item in a collection.
from metaflow import FlowSpec, step
class ForeachFlow(FlowSpec):

    @step
    def start(self):
        """Fan out dynamically: one parallel task per genre."""
        self.genres = ['Action', 'Comedy', 'Drama', 'Sci-Fi']
        # foreach='genres' tells Metaflow to run the next step once
        # for every element of self.genres.
        self.next(self.process_genre, foreach='genres')

    @step
    def process_genre(self):
        """Executed once for each genre; self.input is the current one."""
        print(f"Processing: {self.input}")
        self.result = f"Processed {self.input}"
        self.next(self.join)

    @step
    def join(self, inputs):
        """Gather the results produced by every parallel task."""
        self.all_results = [task.result for task in inputs]
        print(f"Processed {len(inputs)} genres")
        self.next(self.end)

    @step
    def end(self):
        pass
Real-World Example: Movie Statistics
From the Metaflow tutorials:

class MovieStatsFlow(FlowSpec):
@step
def start(self):
    """Parse the raw movie CSV and collect every distinct genre."""
    import csv

    columns = ['movie_title', 'title_year', 'gross', 'genres']
    columns = ['movie_title', 'title_year', 'genres', 'gross']
    self.dataframe = {name: [] for name in columns}
    for row in csv.DictReader(self.movie_data.splitlines()):
        for name in columns:
            # Year and gross are numeric; everything else stays a string.
            value = int(row[name]) if name in ('title_year', 'gross') else row[name]
            self.dataframe[name].append(value)
    # A movie lists its genres '|'-separated; dedupe across all movies.
    distinct = set()
    for field in self.dataframe['genres']:
        distinct.update(field.split('|'))
    self.genres = list(distinct)
    # Fan out: one compute_statistics task per genre.
    self.next(self.compute_statistics, foreach='genres')
@step
def compute_statistics(self):
    """Compute gross-revenue quartiles for a single genre.

    Runs once per foreach branch; ``self.input`` carries the genre
    assigned to this task. Leaves the genre-filtered rows in
    ``self.dataframe`` and the quartile values in ``self.quartiles``.
    """
    self.genre = self.input
    print(f"Computing statistics for {self.genre}")
    # Keep only movies whose genre field mentions this genre.
    # NOTE(review): this is a substring match on the '|'-joined field;
    # confirm no genre name is a substring of another.
    selector = [self.genre in genres for genres in self.dataframe['genres']]
    for col in self.dataframe.keys():
        # Bug fix: the element variable used to be named `col` too,
        # shadowing the column name inside the comprehension.
        self.dataframe[col] = [
            value
            for value, keep in zip(self.dataframe[col], selector)
            if keep
        ]
    # Quartiles are order statistics: the grosses must be sorted before
    # indexing (the original indexed the unsorted, file-ordered list).
    gross_sorted = sorted(self.dataframe['gross'])
    n_points = len(gross_sorted)
    self.quartiles = []
    for cut in [0.25, 0.5, 0.75]:
        if n_points == 0:
            break  # no movies matched: leave quartiles empty
        # Clamp the index: round(n * 0.75) can equal n (e.g. n == 2),
        # which would read past the end of the list.
        idx = 0 if n_points < 2 else min(round(n_points * cut), n_points - 1)
        self.quartiles.append(gross_sorted[idx])
    self.next(self.join)
@step
def join(self, inputs):
    """Merge the per-genre statistics into a single dictionary."""
    stats = {}
    for task in inputs:
        # Key by the lower-cased genre name for uniform lookups.
        stats[task.genre.lower()] = {
            'quartiles': task.quartiles,
            'dataframe': task.dataframe,
        }
    self.genre_stats = stats
    self.next(self.end)
The Foreach Pattern
Syntax
self.next(self.step_name, foreach='attribute_name')
step_name: The step to execute for each item. foreach: The name of an attribute containing an iterable.
Implementation Details
From the source code:

# From flowspec.py:1090
# Validate the foreach target: it must exist on the flow instance and
# yield at least one item.
try:
    # Resolve the attribute named by `foreach` on the flow instance.
    foreach_iter = getattr(self, foreach)
except:
    # NOTE(review): bare except — this also swallows KeyboardInterrupt /
    # SystemExit and reports any failure as a missing variable.
    msg = (
        "Foreach variable *self.{var}* in step *{step}* "
        "does not exist. Check your variable.".format(
            step=step, var=foreach
        )
    )
    raise InvalidNextException(msg)
# Calculate number of splits by consuming the iterable.
# NOTE(review): a one-shot generator is exhausted here — presumably the
# caller re-materializes the values before tasks run; confirm upstream.
self._foreach_num_splits = sum(1 for _ in foreach_iter)
if self._foreach_num_splits == 0:
    # Zero splits means there would be no tasks at all; treat as an error.
    msg = (
        "Foreach iterator over *{var}* in step *{step}* "
        "produced zero splits. Check your variable.".format(
            step=step, var=foreach
        )
    )
    raise InvalidNextException(msg)
Accessing Foreach Data
self.input
In a foreach step, self.input contains the current item:
@step
def start(self):
    """Fan out over a list of numbers."""
    self.numbers = [1, 2, 3, 4, 5]
    self.next(self.square, foreach='numbers')

@step
def square(self):
    """Square the current item; self.input is 1, then 2, then 3, etc."""
    self.result = self.input ** 2
    self.next(self.join)
self.index
The zero-based index of the current item:

@step
def square(self):
    """Report this task's foreach position, then square its item."""
    # self.index is the zero-based position of the current item, e.g.:
    #   Processing item 0: 1
    #   Processing item 1: 2
    #   Processing item 2: 3
    #   ...
    print(f"Processing item {self.index}: {self.input}")
    self.result = self.input ** 2
    self.next(self.join)
# From flowspec.py:607
@property
def index(self) -> Optional[int]:
    """The index of this foreach branch.

    Implicitly returns None when the step is not running inside a
    foreach, i.e. when the internal foreach stack is empty.
    """
    if self._foreach_stack:
        return self._foreach_stack[-1].index
Foreach Stack
For nested foreach loops, use foreach_stack():
@step
def start(self):
    """Outer fan-out over groups."""
    self.groups = ['A', 'B', 'C']
    self.next(self.process_group, foreach='groups')

@step
def process_group(self):
    """Inner fan-out: every group spawns one task per item."""
    self.items = [1, 2, 3, 4]
    self.next(self.process_item, foreach='items')

@step
def process_item(self):
    """Leaf task: one (group, item) pair per execution."""
    # foreach_stack() exposes the whole nesting, one tuple per level:
    # (index, num_splits, value). For the first item of the first
    # group it is [(0, 3, 'A'), (0, 4, 1)].
    frames = self.foreach_stack()
    group_idx, group_count, group_val = frames[0]
    item_idx, item_count, item_val = frames[1]
    print(f"Group {group_val}[{group_idx}], Item {item_val}[{item_idx}]")
    self.next(self.join)
def foreach_stack(self) -> Optional[List[Tuple[int, int, Any]]]:
    """
    Returns the current stack of foreach indexes and values.
    Returns a list where each tuple contains:
    - The index of the task for that level of the loop
    - The number of splits for that level of the loop
    - The value for that level of the loop
    """
    # One tuple per nesting level, outermost first; the value is
    # resolved via _find_input using the frame's position in the stack.
    return [
        (frame.index, frame.num_splits, self._find_input(stack_index=i))
        for i, frame in enumerate(self._foreach_stack)
    ]
Nested Foreach
You can nest foreach loops for multi-dimensional processing:

class NestedForeachFlow(FlowSpec):
@step
def start(self):
    """Outer foreach: one branch per model family."""
    self.models = ['linear', 'tree', 'neural']
    self.next(self.train_model, foreach='models')

@step
def train_model(self):
    """Inner foreach: sweep learning rates for this model."""
    self.model_name = self.input
    self.param_sets = [
        {'lr': 0.01},
        {'lr': 0.001},
        {'lr': 0.0001},
    ]
    self.next(self.tune_params, foreach='param_sets')

@step
def tune_params(self):
    """Train one (model, params) combination and record its score."""
    params = self.input
    self.score = train(self.model_name, params)
    self.next(self.join_params)

@step
def join_params(self, inputs):
    """Pick the best parameter set for this model."""
    winner = max(inputs, key=lambda task: task.score)
    self.best_params = winner.input
    self.best_score = winner.score
    self.next(self.join_models)

@step
def join_models(self, inputs):
    """Collect the per-model winners into one summary dict."""
    self.results = {
        task.model_name: {
            'params': task.best_params,
            'score': task.best_score,
        }
        for task in inputs
    }
    self.next(self.end)
Foreach vs Static Branching
| Feature | Static Branching | Foreach |
|---|---|---|
| Number of branches | Fixed at code time | Dynamic at runtime |
| Branch specification | Explicit: self.next(self.a, self.b) | Implicit: foreach='items' |
| Branch naming | Named methods | Same method, different data |
| Use case | Different operations | Same operation on different data |
Iterable Types
Foreach works with any Python iterable:

# Lists
self.items = [1, 2, 3]
self.next(self.process, foreach='items')
# Tuples
self.items = ('a', 'b', 'c')
self.next(self.process, foreach='items')
# Dictionaries (iterates over keys, like any dict iteration)
self.items = {'x': 1, 'y': 2}
self.next(self.process, foreach='items') # self.input will be 'x', then 'y'
# Ranges
self.items = range(10)
self.next(self.process, foreach='items')
# Generators (must be materialized, e.g. wrapped in list(), before self.next)
self.items = list(x for x in range(10) if x % 2 == 0)
self.next(self.process, foreach='items')
Complex Items
Foreach items can be complex objects:

@step
def start(self):
    """Fan out over structured configuration dicts."""
    self.configs = [
        {'model': 'rf', 'params': {'n_trees': 100}},
        {'model': 'xgb', 'params': {'max_depth': 5}},
        {'model': 'nn', 'params': {'layers': [64, 32]}},
    ]
    self.next(self.train, foreach='configs')

@step
def train(self):
    """Unpack the current config dict and train one model."""
    config = self.input
    self.result = train_model(config['model'], config['params'])
    self.next(self.join)
Join Operations
Collecting Results
@step
def join(self, inputs):
    """Collect the foreach results in several useful shapes."""
    # All results, in task order.
    self.results = [task.result for task in inputs]
    # Results keyed by the foreach item each task processed.
    self.results_dict = {task.input: task.result for task in inputs}
    # The single highest-scoring task.
    best = max(inputs, key=lambda task: task.score)
    self.best_result = best.result
    self.next(self.end)
Aggregation
@step
def join(self, inputs):
    """Aggregate per-task values into a total, an average, and one list."""
    per_task = [task.value for task in inputs]
    self.total = sum(per_task)
    self.average = sum(per_task) / len(inputs)
    # Flatten every task's item list into a single list.
    merged = []
    for task in inputs:
        merged.extend(task.items)
    self.all_items = merged
    self.next(self.end)
Error Handling
By default, if one foreach task fails, the entire flow fails. Use @catch for graceful handling:
from metaflow import catch
@catch(var='error')
@step
def process_item(self):
    """Process one item; @catch stores any raised exception in self.error."""
    if self.input == 'bad':
        raise ValueError("Bad item!")
    self.result = process(self.input)
    self.next(self.join)

@step
def join(self, inputs):
    """Split the finished tasks into successes and failures."""
    successful = [task for task in inputs if task.error is None]
    failed = [task for task in inputs if task.error is not None]
    print(f"Success: {len(successful)}, Failed: {len(failed)}")
    # Only successful tasks carry a result artifact.
    self.results = [task.result for task in successful]
    self.next(self.end)
Performance Considerations
Number of Tasks
# Each foreach item becomes its own task, so scheduler overhead grows
# with the item count.
# Small number of tasks: OK
self.items = range(10)
self.next(self.process, foreach='items') # 10 parallel tasks
# Large number of tasks: Consider batching
self.items = range(10000)
self.next(self.process, foreach='items') # 10,000 tasks - may be too many!
Batching Strategy
@step
def start(self):
    """Group 10,000 items into batches of 100 before fanning out."""
    items = range(10000)
    batch_size = 100
    # Slicing a range yields another range; materialize each slice.
    self.batches = [
        list(items[offset:offset + batch_size])
        for offset in range(0, len(items), batch_size)
    ]
    # 100 tasks instead of 10,000
    self.next(self.process_batch, foreach='batches')

@step
def process_batch(self):
    """Process every item in this task's batch sequentially."""
    batch = self.input
    self.results = [process_item(item) for item in batch]
    self.next(self.join)
Task Size
Balance between too many small tasks and too few large tasks:

# Too granular: Overhead dominates
self.items = [1, 2, 3, ...] # 1000 tiny items
# Too coarse: No parallelism benefit
self.items = [all_data] # 1 huge batch
# Good balance: Reasonable number of medium-sized tasks
self.items = [batch1, batch2, ...] # 10-100 batches
Num Parallel
For simple parallel execution without data:

@step
def start(self):
    """Launch 4 identical parallel instances without any foreach data."""
    self.next(self.train_model, num_parallel=4)

@step
def train_model(self):
    """One of 4 parallel workers; self.index identifies which one."""
    # With num_parallel, self.input and self.index are both the
    # worker index: 0, 1, 2, 3.
    seed = self.index
    self.result = train_with_random_seed(seed)
    self.next(self.join)
# From flowspec.py:1066
# num_parallel reuses the foreach machinery: the split iterates an
# internal ParallelUBF iterable instead of a user-supplied attribute.
if num_parallel is not None and num_parallel >= 1:
    foreach = '_parallel_ubf_iter'
    self._parallel_ubf_iter = ParallelUBF(num_parallel)
Best Practices
1. Keep Foreach Steps Simple
# Good: One clear responsibility
@step
def process_item(self):
    self.result = transform(self.input)
    self.next(self.join)
# Avoid: Multiple operations chained inside one step (each stage can't
# be retried or resumed independently)
@step
def process_item(self):
    self.result1 = transform1(self.input)
    self.result2 = transform2(self.result1)
    self.result3 = transform3(self.result2)
    self.next(self.join)
2. Validate Before Foreach
@step
def start(self):
    """Load the work list and validate it before fanning out."""
    self.items = load_items()
    # Fail fast: a foreach over zero items is an error anyway.
    if not self.items:
        raise ValueError("No items to process!")
    # A huge fan-out is legal but worth flagging up front.
    if len(self.items) > 1000:
        print("Warning: Creating many tasks")
    self.next(self.process, foreach='items')
3. Include Identifiers
@step
def process_item(self):
    """Record the item's id so the result can be traced back to it."""
    self.item_id = self.input.get('id')
    self.result = process(self.input)
    self.next(self.join)

@step
def join(self, inputs):
    """Index every result by the id captured in the foreach step."""
    self.results = {task.item_id: task.result for task in inputs}
    self.next(self.end)
4. Use Appropriate Resources
from metaflow import resources
@resources(memory=4000, cpu=2)
@step
def process_item(self):
    """Heavy per-item work; each parallel task gets 4 GB RAM and 2 CPUs."""
    self.result = expensive_operation(self.input)
    self.next(self.join)
Common Patterns
Map-Reduce
@step
def map_step(self):
    """Split the dataset into chunks and map over them."""
    self.data = load_large_dataset()
    self.chunks = split_into_chunks(self.data, n=10)
    self.next(self.process_chunk, foreach='chunks')

@step
def process_chunk(self):
    """Map phase: transform one chunk independently."""
    self.partial_result = process(self.input)
    self.next(self.reduce)

@step
def reduce(self, inputs):
    """Reduce phase: fold all partial results into the final answer."""
    partials = [task.partial_result for task in inputs]
    self.final_result = combine(partials)
    self.next(self.end)
Grid Search
@step
def start(self):
    """Build the full cartesian parameter grid and fan out over it."""
    import itertools

    param_grid = {
        'learning_rate': [0.01, 0.001, 0.0001],
        'batch_size': [32, 64, 128],
        'layers': [[64], [64, 32], [128, 64, 32]],
    }
    names = list(param_grid.keys())
    # One dict per point in the cartesian product of all value lists.
    self.param_combinations = [
        dict(zip(names, combo))
        for combo in itertools.product(*param_grid.values())
    ]
    self.next(self.train, foreach='param_combinations')

@step
def train(self):
    """Train and score one parameter combination."""
    params = self.input
    self.score = train_and_evaluate(params)
    self.next(self.select_best)

@step
def select_best(self, inputs):
    """Keep the parameters and score of the best run."""
    winner = max(inputs, key=lambda task: task.score)
    self.best_params = winner.input
    self.best_score = winner.score
    self.next(self.end)
Next Steps
Error Handling
Handle failures in foreach tasks with @catch
Data Management
Manage data artifacts efficiently in parallel tasks
Branching
Learn about static branching patterns
Parameters
Use Parameters to configure foreach iterations
