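Here’s a practical example from the Metaflow playlist tutorial. The `start` step loads the movie data and fans out into two parallel branches: one filters movies by genre, and the other picks a random bonus movie from a different genre. The `join` step then combines their results: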
class PlayListFlow(FlowSpec):
    """Build a genre playlist and pick a bonus movie in parallel branches.

    NOTE(review): ``self.genre`` and ``self.movie_data`` are read below but are
    not defined in this excerpt. In the playlist tutorial they are presumably
    a ``Parameter`` (the genre name) and an ``IncludeFile`` (CSV text)
    declared on the flow; confirm before running.
    """

    @step
    def start(self):
        """Load the ``movie_title`` and ``genres`` columns from the CSV data."""
        import csv

        columns = ['movie_title', 'genres']
        self.dataframe = {col: [] for col in columns}
        for row in csv.DictReader(self.movie_data.splitlines()):
            for col in columns:
                self.dataframe[col].append(row[col])

        # Compute genre-specific movies and a bonus movie in parallel.
        self.next(self.bonus_movie, self.genre_movies)

    @step
    def bonus_movie(self):
        """Choose a random movie whose genres do not include ``self.genre``.

        Sets ``self.bonus`` to a ``(movie_title, genres)`` tuple. If every
        movie matches the genre, it is set to ``None`` instead.
        """
        from random import choice

        genre = self.genre.lower()
        movies = [
            (movie, genres)
            for movie, genres in zip(
                self.dataframe['movie_title'], self.dataframe['genres']
            )
            if genre not in genres.lower()
        ]
        # random.choice raises IndexError on an empty sequence; report
        # "no bonus movie" rather than failing the whole run.
        self.bonus = choice(movies) if movies else None
        self.next(self.join)

    @step
    def genre_movies(self):
        """Collect the titles whose genres include ``self.genre``, shuffled."""
        from random import shuffle

        genre = self.genre.lower()
        self.movies = [
            movie
            for movie, genres in zip(
                self.dataframe['movie_title'], self.dataframe['genres']
            )
            if genre in genres.lower()
        ]
        # Shuffle in place so the playlist order varies between runs.
        shuffle(self.movies)
        self.next(self.join)

    @step
    def join(self, inputs):
        """Merge results from both branches.

        The relevant artifacts are reassigned explicitly, looked up by the
        name of the branch step that produced them.
        """
        self.playlist = inputs.genre_movies.movies
        self.bonus = inputs.bonus_movie.bonus
        self.next(self.end)

    @step
    def end(self):
        """Finish the flow; ``self.playlist`` and ``self.bonus`` hold the results."""
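In a join step, the `inputs` parameter gives access to every incoming branch. You can read a branch's artifacts by its step name, or iterate over `inputs` to visit each branch in turn: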
@step
def join(self, inputs):
    """Show the two ways a join step can read artifacts from its branches.

    ``inputs`` can be accessed by the name of an incoming step, or iterated
    to visit every incoming branch in turn.
    """
    # Access by branch name. Assign to ``self`` rather than to locals: only
    # attributes set on ``self`` are persisted as artifacts for later steps.
    self.result_a = inputs.branch_a.some_artifact
    self.result_b = inputs.branch_b.some_artifact

    # Iterate over all inputs.
    for inp in inputs:
        print(inp.result)

    # Collect the same artifact from every branch into one list.
    self.all_results = [inp.result for inp in inputs]
    self.next(self.end)
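Use `merge_artifacts()` to copy artifacts from all incoming branches onto `self` automatically. It only merges artifacts that are not already set in the join step. If an artifact has different values in different branches, the call raises an error unless you set that artifact yourself before calling it: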
@step
def join(self, inputs):
    """Automatically merge artifacts from branches."""
    # Before calling merge_artifacts:
    # - All branches have artifact 'data'
    # - Branch A has 'a_specific'
    # - Branch B has 'b_specific'
    self.merge_artifacts(inputs)
    # After merge_artifacts:
    # - self.data is available if every branch holds the same value. If the
    #   values differ, merge_artifacts raises instead; to resolve the
    #   conflict, assign self.data yourself *before* calling it (artifacts
    #   already set in this step are not merged).
    # - self.a_specific is available
    # - self.b_specific is available
    self.next(self.end)
@step
def join(self, inputs):
    """Merge only a chosen subset of the branch artifacts."""
    # Only merge specific artifacts.
    self.merge_artifacts(inputs, include=['data', 'model'])
    # Or, *instead of* the call above, merge everything except some artifacts.
    # include= and exclude= cannot be combined in one call, and a second call
    # would only consider artifacts not already set in this step:
    # self.merge_artifacts(inputs, exclude=['temp', 'cache'])
    self.next(self.end)
@step
def start(self):
    """Fan out into two top-level branches."""
    self.next(self.branch_a, self.branch_b)

@step
def branch_a(self):
    """Branch A splits further into a1 and a2."""
    self.next(self.a1, self.a2)

@step
def a1(self):
    """First sub-branch of branch A."""
    self.next(self.join_a)

@step
def a2(self):
    """Second sub-branch of branch A."""
    self.next(self.join_a)

@step
def join_a(self, inputs):
    """Join branch A's sub-branches (a1 and a2)."""
    self.merge_artifacts(inputs)
    self.next(self.final_join)

@step
def branch_b(self):
    """Branch B is linear: it goes straight to the final join."""
    self.next(self.final_join)

@step
def final_join(self, inputs):
    """Final join of join_a (all of branch A) and branch_b."""
    self.merge_artifacts(inputs)
    self.next(self.end)