The Input class represents a batch of parallel tree traversals. Each item in the batch maintains its own position (index) and accumulated value as it traverses the tree over multiple rounds.

Dataclass: Input

A batch of parallel traversal tasks with starting indices and values.
problem.py:421-436
@dataclass
class Input:
    """
    A batch of inputs, indices to nodes (starting as 0) and initial input
    values. We then iterate these for a specified number of rounds.
    """

    indices: list[int]
    values: list[int]
    rounds: int

    @staticmethod
    def generate(forest: Tree, batch_size: int, rounds: int):
        indices = [0 for _ in range(batch_size)]
        values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
        return Input(indices, values, rounds)

Fields

indices
list[int]
required
Current tree node index for each batch item. All start at 0 (root) and are updated each traversal step. Length equals batch_size. Values range from 0 to len(tree.values) - 1.
values
list[int]
required
Current accumulated value for each batch item. These are XORed with node values, hashed, and updated at each step. Length equals batch_size. Values are 32-bit unsigned integers.
rounds
int
required
Number of traversal rounds to perform. Each round processes all batch items once. The typical test uses 16 rounds with 256 batch items, totaling 4096 traversal steps.
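The field invariants described above can be checked directly. A minimal self-contained sketch (the dataclass is re-declared here rather than imported, mirroring the definition shown above):

```python
import random
from dataclasses import dataclass

# Re-statement of the documented dataclass so this sketch is self-contained.
@dataclass
class Input:
    indices: list[int]
    values: list[int]
    rounds: int

random.seed(0)
batch_size, rounds = 256, 16
inp = Input(
    indices=[0] * batch_size,  # every item starts at the root (node 0)
    values=[random.randint(0, 2**30 - 1) for _ in range(batch_size)],
    rounds=rounds,
)

# The invariants stated in the field descriptions:
assert len(inp.indices) == len(inp.values) == batch_size
assert all(i == 0 for i in inp.indices)        # all items start at node 0
assert all(0 <= v < 2**32 for v in inp.values)  # values fit in 32 bits
```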

Methods

generate()

Generates a batch of random inputs for a given tree.
forest
Tree
required
The tree that will be traversed (accepted for interface consistency; generate does not read or store it)
batch_size
int
required
Number of parallel traversals in the batch
rounds
int
required
Number of traversal rounds to perform
return
Input
A new Input with all indices at 0 and random 30-bit initial values
problem.py:432-436
@staticmethod
def generate(forest: Tree, batch_size: int, rounds: int):
    indices = [0 for _ in range(batch_size)]
    values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
    return Input(indices, values, rounds)
Example:
import random
from problem import Tree, Input

random.seed(123)
tree = Tree.generate(height=10)
inp = Input.generate(tree, batch_size=256, rounds=16)

print(f"Batch size: {len(inp.indices)}")
print(f"Rounds: {inp.rounds}")
print(f"Total steps: {len(inp.indices) * inp.rounds}")
print(f"Initial indices: {inp.indices[:5]}")  # All 0s
print(f"Initial values: {inp.values[:5]}")    # Random
Output:
Batch size: 256
Rounds: 16
Total steps: 4096
Initial indices: [0, 0, 0, 0, 0]
Initial values: [870398137, 926108597, 184120052, 728364899, 123456789]

Traversal Algorithm

Each batch item independently follows this algorithm for rounds iterations:
1

Read tree node

node_val = tree.values[indices[i]]
2

XOR and hash

values[i] = myhash(values[i] ^ node_val)
3

Choose next node

If the hash is even: next = 2 * indices[i] + 1 (left child). If the hash is odd: next = 2 * indices[i] + 2 (right child).
4

Wrap around

If next >= len(tree.values): next = 0 (back to root)
5

Update

indices[i] = next
problem.py:467-484
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
    cur_inp_val = myhash(cur_inp_val ^ node_val)
    and then choose the left branch if cur_inp_val is even.
    If we reach the bottom of the tree we wrap around to the top.
    """
    for h in range(inp.rounds):
        for i in range(len(inp.indices)):
            idx = inp.indices[i]
            val = inp.values[i]
            val = myhash(val ^ t.values[idx])
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            idx = 0 if idx >= len(t.values) else idx
            inp.values[i] = val
            inp.indices[i] = idx

Memory Layout

Input data is stored in memory by build_mem_image(), excerpted below (extra, a trailing scratch-space term, is defined in the full function):
problem.py:487-513
def build_mem_image(t: Tree, inp: Input) -> list[int]:
    header = 7
    mem = [0] * (header + len(t.values) + len(inp.indices) + len(inp.values) + extra)
    
    # Pointers in header
    inp_indices_p = header + len(t.values)
    inp_values_p = inp_indices_p + len(inp.values)
    
    mem[5] = inp_indices_p  # Pointer to indices array
    mem[6] = inp_values_p   # Pointer to values array
    
    # Data
    mem[inp_indices_p:inp_values_p] = inp.indices
    mem[inp_values_p:] = inp.values
    return mem
Memory map:
Header[5]: inp_indices_p  -> pointer to indices array
Header[6]: inp_values_p   -> pointer to values array

Indices section: [0, 0, 0, ...] (batch_size elements)
Values section:  [random, random, ...] (batch_size elements)
Both indices and values arrays are read AND written during execution. They represent the state that evolves over rounds.
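The header-pointer scheme can be exercised end-to-end with a hand-built image. A simplified sketch with made-up sizes, omitting the trailing extra scratch space:

```python
# Build a tiny memory image following the documented layout: header slots
# 5 and 6 hold pointers to the indices and values arrays.
header = 7
tree_values = [11, 22, 33]   # stand-in tree payload (sizes are illustrative)
indices = [0, 0, 0, 0]       # batch_size = 4, all items at the root
values = [5, 6, 7, 8]        # stand-in initial values

inp_indices_p = header + len(tree_values)
inp_values_p = inp_indices_p + len(indices)

mem = [0] * (header + len(tree_values) + len(indices) + len(values))
mem[5] = inp_indices_p
mem[6] = inp_values_p
mem[header:inp_indices_p] = tree_values
mem[inp_indices_p:inp_values_p] = indices
mem[inp_values_p:inp_values_p + len(values)] = values

# A kernel (or a debugger) recovers the arrays through the header pointers:
batch_size = mem[6] - mem[5]
read_indices = mem[mem[5]:mem[5] + batch_size]
read_values = mem[mem[6]:mem[6] + batch_size]
assert read_indices == indices and read_values == values
```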

Parallelization Opportunities

Independent Batch Items

Each batch item traverses independently - perfect for SIMD parallelism. Process 8 items at once using vector instructions.

Data Dependencies

Each step depends on the previous step’s index and value. No parallelism within a single item’s rounds.

Memory Access

Indices are pseudo-random, creating irregular memory access. Values are contiguous and vectorizable.

Batch Size

Typical batch_size=256 with VLEN=8 means 32 vector iterations. Consider unrolling and cache optimization.

Example: Single Traversal

import random
from problem import Tree, Input, myhash

random.seed(42)
tree = Tree.generate(height=2)  # 7 nodes
inp = Input.generate(tree, batch_size=3, rounds=1)

print("Initial state:")
for i in range(3):
    print(f"  Item {i}: index={inp.indices[i]}, value={inp.values[i]}")
print()

# Simulate one round
for i in range(len(inp.indices)):
    idx = inp.indices[i]
    val = inp.values[i]
    
    print(f"Item {i} at node {idx}:")
    node_val = tree.values[idx]
    print(f"  Node value: {node_val}")
    
    val = myhash(val ^ node_val)
    print(f"  After hash: {val}")
    
    next_idx = 2 * idx + (1 if val % 2 == 0 else 2)
    if next_idx >= len(tree.values):
        next_idx = 0
    print(f"  Next index: {next_idx}")
    
    inp.indices[i] = next_idx
    inp.values[i] = val
    print()

print("After 1 round:")
for i in range(3):
    print(f"  Item {i}: index={inp.indices[i]}, value={inp.values[i]}")

Vectorization Strategy

The key to optimization is processing multiple batch items in parallel:
# Scalar (baseline) - processes 1 item per iteration
for i in range(batch_size):
    idx = indices[i]
    val = values[i]
    # ... process one item ...

# Vector (optimized) - processes 8 items per iteration
for i in range(0, batch_size, 8):
    idx_vec = indices[i:i+8]     # Load 8 indices
    val_vec = values[i:i+8]      # Load 8 values
    # ... process 8 items with valu/vload/vstore ...
Use vload to load 8 contiguous values, but indices are scattered - you’ll need 8 separate scalar loads or gather operations for tree node values.
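The gather step can be sketched in plain Python: the 8 value loads are contiguous, but the 8 tree-node reads go through arbitrary indices and therefore become 8 scalar loads (or a hardware gather, where available). The tree payload and index vector below are made up for illustration:

```python
VLEN = 8
tree_values = list(range(100, 131))   # stand-in tree payload (31 nodes)
idx_vec = [0, 3, 7, 1, 14, 2, 30, 5]  # 8 pseudo-random node indices

# Scattered reads: one scalar load per lane, the non-vectorizable part.
node_vals = [tree_values[idx] for idx in idx_vec]
assert node_vals == [100, 103, 107, 101, 114, 102, 130, 105]
```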

Test Configuration

The performance test uses these parameters:
perf_takehome.py:257-258
def test_kernel_cycles(self):
    do_kernel_test(10, 16, 256)  # height=10, rounds=16, batch_size=256
Tree height
10
2047 nodes (2^11 - 1)
Batch size
256
256 parallel traversals
Rounds
16
16 rounds
Total steps
4096
256 × 16 = 4,096 traversal steps
Baseline cycles
147,734
Scalar implementation from build_kernel()

See Also

Tree

The tree data structure being traversed

KernelBuilder

Build optimized kernels for batch processing

Machine

Simulator that executes the traversal

The batch size (256) is a multiple of VLEN (8), making it ideal for vectorization. Process 8 items together using vector instructions to achieve 8x throughput on the value computations and array updates.
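Because batch items are independent, reordering work into VLEN-wide chunks cannot change the result. A sketch verifying that the scalar loop and a chunked loop agree; fake_hash is a stand-in for problem.myhash (the real hash lives in problem.py), since only the chunking logic matters here:

```python
import random

VLEN = 8

def fake_hash(x: int) -> int:
    # Stand-in 32-bit mixing function, NOT the real myhash from problem.py.
    x = (x ^ (x >> 16)) * 0x45D9F3B & 0xFFFFFFFF
    return x ^ (x >> 13)

def step(idx, val, tree_values):
    # One traversal step, matching reference_kernel's inner loop.
    val = fake_hash(val ^ tree_values[idx])
    idx = 2 * idx + (1 if val % 2 == 0 else 2)
    return (0 if idx >= len(tree_values) else idx), val

random.seed(7)
tree_values = [random.randint(0, 2**30) for _ in range(2047)]
start_vals = [random.randint(0, 2**30) for _ in range(256)]

# Scalar order: one item at a time across the whole batch.
s_idx, s_val = [0] * 256, list(start_vals)
for _ in range(16):
    for i in range(256):
        s_idx[i], s_val[i] = step(s_idx[i], s_val[i], tree_values)

# Chunked order: VLEN items per inner iteration (what vload/vstore would do).
c_idx, c_val = [0] * 256, list(start_vals)
for _ in range(16):
    for base in range(0, 256, VLEN):
        for lane in range(VLEN):
            i = base + lane
            c_idx[i], c_val[i] = step(c_idx[i], c_val[i], tree_values)

assert (s_idx, s_val) == (c_idx, c_val)  # identical results, chunk-friendly order
```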
