The Input class represents a batch of parallel tree traversals. Each item in the batch maintains its own position (index) and accumulated value as it traverses the tree over multiple rounds.
A batch of parallel traversal tasks with starting indices and values.
problem.py:421-436
```python
@dataclass
class Input:
    """
    A batch of inputs, indices to nodes (starting as 0) and initial
    input values. We then iterate these for a specified number of rounds.
    """

    indices: list[int]
    values: list[int]
    rounds: int

    @staticmethod
    def generate(forest: Tree, batch_size: int, rounds: int):
        indices = [0 for _ in range(batch_size)]
        values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
        return Input(indices, values, rounds)
```
`indices`: Current tree node index for each batch item. All start at 0 (the root) and are updated at each traversal step. Length equals `batch_size`; values range from 0 to `len(tree.values) - 1`.
`values`: Current accumulated value for each batch item. These are XORed with node values, hashed, and updated at each step. Length equals `batch_size`; values are 32-bit unsigned integers.
`rounds`: Number of traversal rounds to perform. Each round processes all batch items once. The typical test uses 16 rounds with 256 batch items, totaling 4096 traversal steps.
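As a usage sketch, the class above can be exercised with a minimal stand-in `Tree` (assumed here, since only `Input` is shown):

```python
import random
from dataclasses import dataclass

@dataclass
class Tree:
    # Minimal stand-in: the tree is stored as a flat array of node values.
    values: list[int]

@dataclass
class Input:
    indices: list[int]
    values: list[int]
    rounds: int

    @staticmethod
    def generate(forest: Tree, batch_size: int, rounds: int) -> "Input":
        indices = [0 for _ in range(batch_size)]
        values = [random.randint(0, 2**30 - 1) for _ in range(batch_size)]
        return Input(indices, values, rounds)

tree = Tree(values=[random.getrandbits(32) for _ in range(127)])
inp = Input.generate(tree, batch_size=256, rounds=16)
assert inp.indices == [0] * 256   # every item starts at the root
assert len(inp.values) == 256     # one value per batch item
```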
Each batch item independently follows this algorithm for `rounds` iterations:

1. Read tree node: `node_val = tree.values[indices[i]]`
2. XOR and hash: `values[i] = myhash(values[i] ^ node_val)`
3. Choose next node: if the hash is even, `next = 2 * indices[i] + 1` (left child); if odd, `next = 2 * indices[i] + 2` (right child)
4. Wrap around: if `next >= len(tree.values)`, `next = 0` (back to root)
5. Update: `indices[i] = next`
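The five steps above can be sketched as a single-item update function. Since the definition of `myhash` is not shown here, a stand-in 32-bit mixer is used in its place:

```python
def myhash_stub(x: int) -> int:
    # Stand-in for the problem's myhash (definition not shown); any 32-bit mixer works.
    x = ((x ^ (x >> 16)) * 0x45D9F3B) & 0xFFFFFFFF
    return (x ^ (x >> 16)) & 0xFFFFFFFF

def step(tree_values: list[int], idx: int, val: int) -> tuple[int, int]:
    node_val = tree_values[idx]                  # 1. read tree node
    val = myhash_stub(val ^ node_val)            # 2. XOR and hash
    nxt = 2 * idx + (1 if val % 2 == 0 else 2)   # 3. left child if even, right if odd
    if nxt >= len(tree_values):                  # 4. wrap around to the root
        nxt = 0
    return nxt, val                              # 5. update

idx, val = step([7, 3, 9], 0, 42)
assert idx in (1, 2) and 0 <= val < 2**32
```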
problem.py:467-484
```python
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
    cur_inp_val = myhash(cur_inp_val ^ node_val) and then choose the left
    branch if cur_inp_val is even. If we reach the bottom of the tree we
    wrap around to the top.
    """
    for h in range(inp.rounds):
        for i in range(len(inp.indices)):
            idx = inp.indices[i]
            val = inp.values[i]
            val = myhash(val ^ t.values[idx])
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            idx = 0 if idx >= len(t.values) else idx
            inp.values[i] = val
            inp.indices[i] = idx
```
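The reference kernel can be exercised end to end with small stand-ins. The `Tree`, `Input`, and `myhash` definitions below are assumptions for the sake of a runnable sketch; in particular, the real `myhash` is not shown in this excerpt:

```python
import random
from dataclasses import dataclass

@dataclass
class Tree:
    values: list[int]

@dataclass
class Input:
    indices: list[int]
    values: list[int]
    rounds: int

def myhash(x: int) -> int:
    # Stand-in 32-bit mixer; the problem's real myhash differs.
    x = ((x ^ (x >> 16)) * 0x45D9F3B) & 0xFFFFFFFF
    return (x ^ (x >> 16)) & 0xFFFFFFFF

def reference_kernel(t: Tree, inp: Input):
    for h in range(inp.rounds):
        for i in range(len(inp.indices)):
            idx = inp.indices[i]
            val = inp.values[i]
            val = myhash(val ^ t.values[idx])
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            idx = 0 if idx >= len(t.values) else idx
            inp.values[i] = val
            inp.indices[i] = idx

t = Tree(values=[random.getrandbits(32) for _ in range(127)])  # depth-7 complete tree
inp = Input(indices=[0] * 4, values=[1, 2, 3, 4], rounds=16)
reference_kernel(t, inp)
assert all(0 <= i < len(t.values) for i in inp.indices)  # indices stay in range
assert all(0 <= v < 2**32 for v in inp.values)           # values stay 32-bit
```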
The key to optimization is processing multiple batch items in parallel:
```python
# Scalar (baseline) - processes 1 item per iteration
for i in range(batch_size):
    idx = indices[i]
    val = values[i]
    # ... process one item ...

# Vector (optimized) - processes 8 items per iteration
for i in range(0, batch_size, 8):
    idx_vec = indices[i:i+8]  # Load 8 indices
    val_vec = values[i:i+8]   # Load 8 values
    # ... process 8 items with valu/vload/vstore ...
```
Use `vload` to load 8 contiguous values, but the indices are scattered: fetching the 8 tree node values requires 8 separate scalar loads or a gather operation.
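Because the 8 lanes usually point at 8 different tree nodes, the node-value load is a gather rather than a contiguous `vload`. A sketch of the lane-by-lane fallback, with vector registers modeled as plain Python lists of 8 lanes:

```python
VLEN = 8

def gather(tree_values: list[int], idx_vec: list[int]) -> list[int]:
    # 8 scalar loads standing in for a hardware gather:
    # lane j reads tree_values[idx_vec[j]].
    return [tree_values[idx] for idx in idx_vec]

tree_values = [10, 20, 30, 40, 50, 60, 70]
idx_vec = [0, 2, 2, 5, 1, 6, 0, 3]  # scattered indices, one per lane
assert gather(tree_values, idx_vec) == [10, 30, 30, 60, 20, 70, 10, 40]
```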
The batch size (256) is a multiple of VLEN (8), making it ideal for vectorization. Process 8 items together using vector instructions to achieve 8x throughput on the value computations and array updates.
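Putting the pieces together, the 8-wide main loop can be sketched in plain Python, with vector registers modeled as lists and a stand-in `myhash` (the real one is not shown here). The lane results match a per-item scalar loop exactly, which is a useful correctness check before optimizing:

```python
VLEN = 8

def myhash(x: int) -> int:
    # Stand-in 32-bit mixer; the problem's real myhash differs.
    x = ((x ^ (x >> 16)) * 0x45D9F3B) & 0xFFFFFFFF
    return (x ^ (x >> 16)) & 0xFFFFFFFF

def vector_round(tree_values, indices, values):
    n = len(tree_values)
    for i in range(0, len(indices), VLEN):
        idx_vec = indices[i:i + VLEN]                     # vload indices
        val_vec = values[i:i + VLEN]                      # vload values
        node_vec = [tree_values[idx] for idx in idx_vec]  # gather tree nodes
        val_vec = [myhash(v ^ nv)                         # valu: XOR + hash
                   for v, nv in zip(val_vec, node_vec)]
        idx_vec = [2 * idx + (1 if v % 2 == 0 else 2)     # pick child per lane
                   for idx, v in zip(idx_vec, val_vec)]
        idx_vec = [0 if idx >= n else idx                 # wrap to root
                   for idx in idx_vec]
        indices[i:i + VLEN] = idx_vec                     # vstore
        values[i:i + VLEN] = val_vec

# Check the vector loop against a per-item scalar loop.
tree_values = [(i * 2654435761) & 0xFFFFFFFF for i in range(127)]
vi, vv = [0] * 32, list(range(32))
si, sv = [0] * 32, list(range(32))
for _ in range(16):
    vector_round(tree_values, vi, vv)
    for j in range(32):
        v = myhash(sv[j] ^ tree_values[si[j]])
        idx = 2 * si[j] + (1 if v % 2 == 0 else 2)
        si[j] = 0 if idx >= len(tree_values) else idx
        sv[j] = v
assert vi == si and vv == sv  # lanes agree with the scalar reference
```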