Overview

Branching lets you run multiple steps in parallel and then join their results. Typical uses include:
  • Processing independent data transformations
  • Running multiple models simultaneously
  • A/B testing different approaches
  • Comparing algorithm variants

Static Fan-Out

The simplest form of branching is static fan-out, where you explicitly name the parallel branches:
from metaflow import FlowSpec, step

class BranchingFlow(FlowSpec):
    @step
    def start(self):
        """Split into multiple parallel branches."""
        self.data = [1, 2, 3, 4, 5]
        # Launch three branches in parallel
        self.next(self.branch_a, self.branch_b, self.branch_c)
    
    @step
    def branch_a(self):
        self.result = sum(self.data)
        self.next(self.join)
    
    @step
    def branch_b(self):
        self.result = max(self.data)
        self.next(self.join)
    
    @step
    def branch_c(self):
        self.result = min(self.data)
        self.next(self.join)
    
    @step
    def join(self, inputs):
        """Join the parallel branches."""
        print(f"Sum: {inputs.branch_a.result}")
        print(f"Max: {inputs.branch_b.result}")
        print(f"Min: {inputs.branch_c.result}")
        self.next(self.end)
    
    @step
    def end(self):
        pass

if __name__ == '__main__':
    BranchingFlow()
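Outside Metaflow, the same fan-out/join shape can be sketched with the standard library's concurrent.futures. This is only an analogy, assuming local threads rather than Metaflow's separate tasks: each branch function receives its own copy of the input, and the join collects results keyed by branch name, much like inputs.branch_a.result. The function names are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def branch_a(data):
    return sum(data)


def branch_b(data):
    return max(data)


def branch_c(data):
    return min(data)


def run_fan_out(data):
    """Run every branch on the same input in parallel, then join the results."""
    branches = {"branch_a": branch_a, "branch_b": branch_b, "branch_c": branch_c}
    with ThreadPoolExecutor(max_workers=len(branches)) as pool:
        # Fan out: submit one task per branch, each with its own copy of the data
        futures = {name: pool.submit(fn, list(data)) for name, fn in branches.items()}
        # Join: wait for every branch and collect its result by branch name
        return {name: future.result() for name, future in futures.items()}


if __name__ == "__main__":
    results = run_fan_out([1, 2, 3, 4, 5])
    print(f"Sum: {results['branch_a']}")  # Sum: 15
    print(f"Max: {results['branch_b']}")  # Max: 5
    print(f"Min: {results['branch_c']}")  # Min: 1
```

As in the flow, the join only proceeds once all three branches have returned.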

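Real-World Example: Building a Playlist

Here’s a practical example adapted from Metaflow's playlist tutorial. It loads a movie CSV, then builds a genre playlist and picks a bonus movie in parallel: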
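from metaflow import FlowSpec, IncludeFile, Parameter, step

class PlayListFlow(FlowSpec):
    # Inputs read by the steps below: the CSV contents and the genre to filter on
    movie_data = IncludeFile('movie_data',
                             help='The path to a movie metadata file.',
                             default='movies.csv')
    genre = Parameter('genre',
                      help='Filter movies for a particular genre.',
                      default='Sci-Fi')
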
    @step
    def start(self):
        """Load movie data."""
        import csv
        columns = ['movie_title', 'genres']
        self.dataframe = {col: [] for col in columns}
        
        for row in csv.DictReader(self.movie_data.splitlines()):
            for col in columns:
                self.dataframe[col].append(row[col])
        
        # Compute genre-specific movies and a bonus movie in parallel
        self.next(self.bonus_movie, self.genre_movies)
    
    @step
    def bonus_movie(self):
        """Choose a random movie from a different genre."""
        from random import choice
        
        movies = [
            (movie, genres)
            for movie, genres in zip(
                self.dataframe['movie_title'],
                self.dataframe['genres']
            )
            if self.genre.lower() not in genres.lower()
        ]
        
        self.bonus = choice(movies)
        self.next(self.join)
    
    @step
    def genre_movies(self):
        """Filter movies by genre."""
        from random import shuffle
        
        self.movies = [
            movie
            for movie, genres in zip(
                self.dataframe['movie_title'],
                self.dataframe['genres']
            )
            if self.genre.lower() in genres.lower()
        ]
        
        shuffle(self.movies)
        self.next(self.join)
    
    @step
    def join(self, inputs):
        """Merge results from both branches."""
        self.playlist = inputs.genre_movies.movies
        self.bonus = inputs.bonus_movie.bonus
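        self.next(self.end)

    @step
    def end(self):
        """Print the playlist and the bonus pick."""
        print(f"Playlist for movies in genre '{self.genre}'")
        for pick, movie in enumerate(self.playlist, start=1):
            print(f"Pick {pick}: '{movie}'")
        print(f"Bonus Pick: '{self.bonus[0]}' from '{self.bonus[1]}'")

if __name__ == '__main__':
    PlayListFlow()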
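The two branch bodies are ordinary Python, so their filtering logic can be exercised without running the flow. A minimal sketch, assuming a small made-up CSV in place of the tutorial's movies.csv; load_dataframe, genre_movies, and bonus_movie here are hypothetical plain functions that copy the step logic, not Metaflow steps.

```python
import csv
from random import choice, shuffle

# Made-up sample data standing in for the tutorial's movies.csv
MOVIE_DATA = """movie_title,genres
Alien,Horror|Sci-Fi
Arrival,Drama|Sci-Fi
Up,Animation|Comedy
Heat,Crime|Thriller
"""


def load_dataframe(movie_data, columns=("movie_title", "genres")):
    """Build the column dictionary that the start step stores in self.dataframe."""
    dataframe = {col: [] for col in columns}
    for row in csv.DictReader(movie_data.splitlines()):
        for col in columns:
            dataframe[col].append(row[col])
    return dataframe


def genre_movies(dataframe, genre):
    """Same filter as the genre_movies step: titles whose genres mention `genre`."""
    movies = [
        movie
        for movie, genres in zip(dataframe["movie_title"], dataframe["genres"])
        if genre.lower() in genres.lower()
    ]
    shuffle(movies)
    return movies


def bonus_movie(dataframe, genre):
    """Same logic as the bonus_movie step: a random (title, genres) outside `genre`."""
    movies = [
        (movie, genres)
        for movie, genres in zip(dataframe["movie_title"], dataframe["genres"])
        if genre.lower() not in genres.lower()
    ]
    return choice(movies)


df = load_dataframe(MOVIE_DATA)
print(sorted(genre_movies(df, "Sci-Fi")))  # ['Alien', 'Arrival']
print(bonus_movie(df, "Sci-Fi"))           # a random non-Sci-Fi pick, e.g. ('Up', 'Animation|Comedy')
```

Because the genre test is a case-insensitive substring match, "sci-fi" and "Sci-Fi" select the same movies.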

Understanding Join Steps

Accessing Branch Data

In a join step, the inputs parameter provides access to all incoming branches:
@step
def join(self, inputs):
    # Access by branch name
    result_a = inputs.branch_a.some_artifact
    result_b = inputs.branch_b.some_artifact
    
    # Iterate over all inputs
    for inp in inputs:
        print(inp.result)
    
    # List comprehension
    all_results = [inp.result for inp in inputs]
    self.next(self.end)
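Because inputs supports both attribute access by step name and iteration, join logic can be unit-tested with a small stand-in. The FakeInputs class below is a hypothetical test helper, not part of Metaflow; it assumes each incoming branch can be represented by an object carrying that branch's artifacts (here types.SimpleNamespace).

```python
from types import SimpleNamespace


class FakeInputs:
    """Hypothetical stand-in for a join step's `inputs` (not a Metaflow class)."""

    def __init__(self, **branches):
        self._branches = list(branches.values())
        # Expose each branch by step name, e.g. inputs.branch_a
        for name, branch in branches.items():
            setattr(self, name, branch)

    def __iter__(self):
        # Iterate over branches in the order they were passed in
        return iter(self._branches)


def summarize(inputs):
    """Join logic under test: read one branch by name and all branches by iteration."""
    first = inputs.branch_a.result
    all_results = [inp.result for inp in inputs]
    return first, all_results


inputs = FakeInputs(
    branch_a=SimpleNamespace(result=15),
    branch_b=SimpleNamespace(result=5),
)
print(summarize(inputs))  # (15, [15, 5])
```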

Join Step Signature

Join steps are detected automatically by their signature:
# From graph.py:218
if self.name == 'start':
    self.type = 'start'
elif self.num_args > 1:  # Has 'inputs' parameter
    self.type = 'join'
else:
    self.type = 'linear'
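Any step that accepts an argument besides self (conventionally named inputs) is treated as a join step. Note that num_args counts self, so a join step has two arguments.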
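The rule is simple enough to reproduce with Python's ast module. This is a simplified sketch, assuming only the three cases in the excerpt above (start, join, linear); Metaflow's real graph builder also recognizes other step types such as split, foreach, and end, and inspects each step's self.next call. The classify_steps helper is hypothetical.

```python
import ast
import textwrap

SOURCE = textwrap.dedent('''
    class BranchingFlow:
        def start(self):
            pass

        def branch_a(self):
            pass

        def join(self, inputs):
            pass
''')


def classify_steps(source):
    """Map each method name to 'start', 'join', or 'linear' using the argument count."""
    kinds = {}
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            num_args = len(node.args.args)  # counts `self`, so (self, inputs) -> 2
            if node.name == "start":
                kinds[node.name] = "start"
            elif num_args > 1:
                kinds[node.name] = "join"
            else:
                kinds[node.name] = "linear"
    return kinds


print(classify_steps(SOURCE))
# {'start': 'start', 'branch_a': 'linear', 'join': 'join'}
```

Parsing the source rather than calling the methods means the classification is known before any step runs.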

Graph Structure

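A flow with branches forms a directed acyclic graph (DAG). While building the graph, Metaflow walks it from start and records, for each split, the join that closes it: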
# From graph.py:308
def traverse(node, seen, split_parents, split_branches):
    if node.type in ('split', 'foreach'):
        node.split_parents = split_parents
        split_parents = split_parents + [node.name]
    elif node.type == 'join':
        if split_parents:
            self[split_parents[-1]].matching_join = node.name
Each split tracks its matching join for proper graph traversal.
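The split_parents list in the excerpt behaves like a stack: entering a split pushes its name, and reaching a join pairs that join with the innermost open split. The sketch below is a simplified stand-in for Metaflow's traversal, assuming each step is described only by its type and successors, and applies the idea to the layout from the Nested Branching section. GRAPH and match_joins are illustrative names, not Metaflow APIs.

```python
# Each step: (type, list of successor steps). Mirrors the nested-branching example.
GRAPH = {
    "start": ("split", ["branch_a", "branch_b"]),
    "branch_a": ("split", ["a1", "a2"]),
    "a1": ("linear", ["join_a"]),
    "a2": ("linear", ["join_a"]),
    "join_a": ("join", ["final_join"]),
    "branch_b": ("linear", ["final_join"]),
    "final_join": ("join", ["end"]),
    "end": ("end", []),
}


def match_joins(graph, start="start"):
    """Return {split_step: matching_join_step} by walking the DAG depth-first."""
    matching_join = {}

    def traverse(name, split_parents):
        step_type, successors = graph[name]
        if step_type == "split":
            # Entering a split: push it onto the stack of open splits
            split_parents = split_parents + [name]
        elif step_type == "join" and split_parents:
            # A join closes the innermost open split, which is then popped
            matching_join[split_parents[-1]] = name
            split_parents = split_parents[:-1]
        # No `seen` set here, so shared steps are revisited; fine for small graphs
        for successor in successors:
            traverse(successor, split_parents)

    traverse(start, [])
    return matching_join


print(match_joins(GRAPH))  # {'branch_a': 'join_a', 'start': 'final_join'}
```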

Merge Artifacts Helper

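Use merge_artifacts() in a join step to carry branch artifacts forward automatically. An artifact set in only some branches, or set to the same value in all of them, is merged; an artifact with conflicting values raises an error unless the join step assigns it first: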
@step
def join(self, inputs):
    """Automatically merge artifacts from branches."""
    # Before calling merge_artifacts:
    # - All branches have artifact 'data'
    # - Branch A has 'a_specific'
    # - Branch B has 'b_specific'
    
    self.merge_artifacts(inputs)
    
    # After merge_artifacts:
    # - self.data is available (if same in all branches)
    # - self.a_specific is available
    # - self.b_specific is available
    
    self.next(self.end)
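The merge rules can be modeled with plain dictionaries. This is a simplified sketch, not Metaflow's implementation: it assumes each branch's artifacts are a dict and compares values with ==, whereas Metaflow compares the persisted artifacts. merge_branch_artifacts and MergeConflict are hypothetical names; the model also rejects passing both include and exclude.

```python
class MergeConflict(Exception):
    """Raised when branches disagree on an artifact the join has not set."""


def merge_branch_artifacts(joined, branches, include=None, exclude=None):
    """Simplified model of merge_artifacts: update `joined` in place from `branches`."""
    if include and exclude:
        raise ValueError("include and exclude are mutually exclusive")
    exclude = set(exclude or [])
    candidates = {}
    for branch in branches:
        for name, value in branch.items():
            if name in joined:  # already set in the join step: never overwritten
                continue
            if include and name not in include:
                continue
            if name in exclude:
                continue
            candidates.setdefault(name, []).append(value)
    for name, values in candidates.items():
        # Present in one branch, or identical everywhere: safe to merge
        if any(value != values[0] for value in values[1:]):
            raise MergeConflict(f"artifact {name!r} differs between branches")
        joined[name] = values[0]
    return joined


branch_a = {"data": [1, 2, 3], "a_specific": "from A"}
branch_b = {"data": [1, 2, 3], "b_specific": "from B"}
print(merge_branch_artifacts({}, [branch_a, branch_b]))
# {'data': [1, 2, 3], 'a_specific': 'from A', 'b_specific': 'from B'}
```

Passing a pre-set value in joined, such as {'result': 'B'}, skips that artifact, which mirrors the manual resolution pattern shown under Handling Conflicts.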

Handling Conflicts

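If branches set the same artifact to different values, merge_artifacts raises an error. Resolve the conflict yourself by assigning the artifact in the join step before calling merge_artifacts: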
@step
def branch_a(self):
    self.result = "A"
    self.next(self.join)

@step
def branch_b(self):
    self.result = "B"  # Conflict!
    self.next(self.join)

@step
def join(self, inputs):
    # Manually resolve conflict BEFORE merge_artifacts
    self.result = max(inp.result for inp in inputs)
    
    # merge_artifacts skips artifacts already set in this step, so no conflict is raised
    self.merge_artifacts(inputs)
    self.next(self.end)

Selective Merging
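
Pass include to merge only the named artifacts, or exclude to merge all others:
@step
def join(self, inputs):
    # Option 1: merge only the listed artifacts
    self.merge_artifacts(inputs, include=['data', 'model'])

    # Option 2: merge everything except the listed artifacts
    # self.merge_artifacts(inputs, exclude=['temp', 'cache'])
    # (include and exclude cannot be combined in one call)

    self.next(self.end)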

Nested Branching

Branches can split further. Each inner split needs its own join, and that join then feeds into the outer join:
@step
def start(self):
    self.next(self.branch_a, self.branch_b)

@step
def branch_a(self):
    # Branch A splits further
    self.next(self.a1, self.a2)

@step
def a1(self):
    self.next(self.join_a)

@step
def a2(self):
    self.next(self.join_a)

@step
def join_a(self, inputs):
    # Join for branch_a's sub-branches
    self.merge_artifacts(inputs)
    self.next(self.final_join)

@step
def branch_b(self):
    # Branch B is linear
    self.next(self.final_join)

@step
def final_join(self, inputs):
    # Final join of branch_a and branch_b
    self.merge_artifacts(inputs)
    self.next(self.end)

Best Practices

1. Independent Branches

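Branches run as separate tasks, so they cannot see each other's artifacts and should not depend on execution order. Each branch receives its own copy of the artifacts created before the split: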
# Good: Each branch is independent
@step
def start(self):
    self.data = load_data()
    self.next(self.model_a, self.model_b)

@step
def model_a(self):
    self.score = train_model_a(self.data)  # Uses self.data
    self.next(self.join)

@step
def model_b(self):
    self.score = train_model_b(self.data)  # Uses self.data
    self.next(self.join)
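The isolation between branches can be illustrated in plain Python. A minimal analogy, assuming copy.deepcopy as a stand-in for each Metaflow task loading the pre-split artifacts independently; run_branches_independently, model_a, and model_b are hypothetical functions.

```python
import copy


def run_branches_independently(artifacts, branches):
    """Give every branch its own deep copy of the pre-split artifacts."""
    results = {}
    for name, branch in branches.items():
        local = copy.deepcopy(artifacts)  # like a fresh load in a separate task
        results[name] = branch(local)
    return results


def model_a(artifacts):
    artifacts["data"].append(99)  # this mutation stays inside this branch
    return len(artifacts["data"])


def model_b(artifacts):
    return len(artifacts["data"])  # still sees the original three items


shared = {"data": [1, 2, 3]}
print(run_branches_independently(shared, {"model_a": model_a, "model_b": model_b}))
# {'model_a': 4, 'model_b': 3}
```

Because no branch can observe another's changes, the order in which branches run never affects the results.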

2. Clear Branch Names

Use descriptive names that explain what each branch does:
# Good
self.next(self.train_random_forest, self.train_xgboost, self.train_neural_net)

# Bad
self.next(self.branch1, self.branch2, self.branch3)

3. Balance Branch Workload

The join cannot start until every branch has finished, so the slowest branch sets the pace. Try to distribute work evenly across branches:
# Consider: Does one branch take much longer?
@step
def start(self):
    self.next(
        self.quick_model,      # 1 minute
        self.expensive_model   # 30 minutes
    )
    # The join waits for expensive_model, so this fan-out takes at least 30 minutes

4. Handle Missing Artifacts

Not every branch necessarily produces the same artifacts:
@step
def join(self, inputs):
    # Safe: skip branches that did not set 'score'
    scores = [inp.score for inp in inputs if hasattr(inp, 'score')]
    # Guard against division by zero when no branch produced a score
    self.average_score = sum(scores) / len(scores) if scores else None
    self.next(self.end)

Switch Statements

For conditional branching, see the switch pattern:
@step
def start(self):
    self.env = 'production'
    self.next({
        'development': self.dev_path,
        'staging': self.staging_path,
        'production': self.prod_path
    }, condition='env')
A switch takes exactly one branch, chosen by the value of the artifact named in condition. This differs from parallel fan-out, which executes every branch.

Visualizing Branches

View your flow graph to see branches:
python myflow.py show
Or export the graph in Graphviz DOT format and render it as an image:
python myflow.py output-dot > flow.dot
dot -Tpng flow.dot -o flow.png

Common Patterns

A/B Testing

@step
def start(self):
    self.data = load_data()
    self.next(self.variant_a, self.variant_b)

@step
def variant_a(self):
    self.variant = 'variant_a'  # record which branch produced this metric
    self.metric = run_variant_a(self.data)
    self.next(self.compare)

@step
def variant_b(self):
    self.variant = 'variant_b'
    self.metric = run_variant_b(self.data)
    self.next(self.compare)

@step
def compare(self, inputs):
    # Pick the branch with the highest metric
    best = max(inputs, key=lambda inp: inp.metric)
    self.winner = best.variant
    self.next(self.end)

Data Pipeline

@step
def start(self):
    self.raw_data = load_raw_data()
    self.next(self.clean, self.validate, self.enrich)

@step
def clean(self):
    self.cleaned = clean_data(self.raw_data)
    self.next(self.join)

@step
def validate(self):
    self.validation_report = validate_data(self.raw_data)
    self.next(self.join)

@step
def enrich(self):
    self.enriched = enrich_data(self.raw_data)
    self.next(self.join)

@step
def join(self, inputs):
    self.merge_artifacts(inputs)
    # Now available: cleaned, validation_report, enriched (and raw_data, identical in every branch)
    self.next(self.end)

Next Steps

Foreach Patterns

Process dynamic data in parallel with foreach

Data Management

Best practices for managing artifacts across branches

Error Handling

Handle failures in parallel branches

FlowSpec

Learn more about the FlowSpec base class