
Overview

The performance challenge provides two reference implementations and a hash function that define the expected behavior. Your optimized kernel must produce identical results.
The reference kernels are not meant to be fast—they’re clear, correct specifications of what your kernel should compute.

reference_kernel() - High-Level Python

The simplest specification of the algorithm:
problem.py:467-484
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
    cur_inp_val = myhash(cur_inp_val ^ node_val)
    and then choose the left branch if cur_inp_val is even.
    If we reach the bottom of the tree we wrap around to the top.
    """
    for h in range(inp.rounds):
        for i in range(len(inp.indices)):
            idx = inp.indices[i]
            val = inp.values[i]
            val = myhash(val ^ t.values[idx])  # XOR with node, then hash
            idx = 2 * idx + (1 if val % 2 == 0 else 2)  # Navigate tree
            idx = 0 if idx >= len(t.values) else idx     # Wrap to root
            inp.values[i] = val
            inp.indices[i] = idx

Algorithm Breakdown

Data Structure: Implicit perfect binary tree
  • Node 0 is the root
  • Node i has children at 2*i + 1 (left) and 2*i + 2 (right)
  • Total nodes: 2^(height+1) - 1
Traversal Rules:
problem.py:479-482
val = myhash(val ^ t.values[idx])              # Update value
idx = 2 * idx + (1 if val % 2 == 0 else 2)    # Go left if even, right if odd
idx = 0 if idx >= len(t.values) else idx       # Wrap to root if out of bounds
Why it wraps: After reaching a leaf node, the next index would exceed the tree size, so we reset to the root.
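
These rules can be checked in isolation on a toy tree (a stand-alone sketch; the `even` flags below stand in for the hashed value's parity):

```python
# Toy walk on an implicit perfect binary tree of height 2 (7 nodes);
# the "even" flags replace the real hashed values.
n_nodes = 2 ** (2 + 1) - 1  # height 2 -> 7 nodes

def step(idx: int, even: bool) -> int:
    idx = 2 * idx + (1 if even else 2)   # left child if even, else right
    return 0 if idx >= n_nodes else idx  # wrap to the root past the leaves

path = [0]
for even in (True, False, True):
    path.append(step(path[-1], even))

print(path)  # [0, 1, 4, 0] -- the third step runs off the leaves and wraps
```
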
Batch Processing:
problem.py:477-478
for h in range(inp.rounds):           # Outer loop: rounds
    for i in range(len(inp.indices)): # Inner loop: batch of inputs
Inputs:
  • inp.indices - Current position of each input in the tree
  • inp.values - Current value for each input
  • inp.rounds - Number of traversal iterations
Independence: Each input’s traversal is independent (parallelism opportunity!).
Value Update:
problem.py:480
val = myhash(val ^ t.values[idx])
  1. Read the tree node value at current index
  2. XOR with the input value
  3. Hash the result using myhash()
This ensures the traversal path depends on both the input and tree structure.
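
To see the effect, here is a minimal sketch with a stand-in mixer (`toy_hash` is hypothetical; the real `myhash()` is defined by `HASH_STAGES`, and the node values below are made up):

```python
# toy_hash() is a stand-in avalanche-style mixer, not the real myhash().
def toy_hash(x: int) -> int:
    x = (x ^ (x >> 16)) * 0x45D9F3B % 2**32
    return (x ^ (x >> 16)) % 2**32

tree_values = [5, 9, 3]  # hypothetical node values
val = 7                  # hypothetical input value

# The same input value meets different nodes -> different hashed
# values -> potentially different left/right decisions.
a = toy_hash(val ^ tree_values[0])
b = toy_hash(val ^ tree_values[1])
print(a % 2, b % 2)  # parity decides the branch at each node
```
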

reference_kernel2() - Memory-Based with Tracing

A lower-level version that operates on flat memory and records intermediate values:
problem.py:535-568
def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}):
    """
    Reference implementation of the kernel on a flat memory.
    """
    # Load parameters from memory header
    rounds = mem[0]
    n_nodes = mem[1]
    batch_size = mem[2]
    forest_height = mem[3]
    forest_values_p = mem[4]  # Pointer to tree values
    inp_indices_p = mem[5]    # Pointer to input indices
    inp_values_p = mem[6]     # Pointer to input values
    
    yield mem  # Initial state
    
    for h in range(rounds):
        for i in range(batch_size):
            idx = mem[inp_indices_p + i]
            trace[(h, i, "idx")] = idx
            
            val = mem[inp_values_p + i]
            trace[(h, i, "val")] = val
            
            node_val = mem[forest_values_p + idx]
            trace[(h, i, "node_val")] = node_val
            
            val = myhash_traced(val ^ node_val, trace, h, i)
            trace[(h, i, "hashed_val")] = val
            
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            trace[(h, i, "next_idx")] = idx
            
            idx = 0 if idx >= n_nodes else idx
            trace[(h, i, "wrapped_idx")] = idx
            
            mem[inp_values_p + i] = val
            mem[inp_indices_p + i] = idx
    
    yield mem  # Final state

Memory Layout

The build_mem_image() function creates a flat memory image:
problem.py:487-513
def build_mem_image(t: Tree, inp: Input) -> list[int]:
    header = 7
    extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
    mem = [0] * (header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room)
    
    forest_values_p = header
    inp_indices_p = forest_values_p + len(t.values)
    inp_values_p = inp_indices_p + len(inp.values)
    
    # Header:
    mem[0] = inp.rounds
    mem[1] = len(t.values)
    mem[2] = len(inp.indices)
    mem[3] = t.height
    mem[4] = forest_values_p
    mem[5] = inp_indices_p
    mem[6] = inp_values_p
    
    # Data:
    mem[forest_values_p:inp_indices_p] = t.values
    mem[inp_indices_p:inp_values_p] = inp.indices
    mem[inp_values_p:] = inp.values
    
    return mem
Memory Map:
[0] rounds
[1] n_nodes
[2] batch_size
[3] forest_height
[4] forest_values_p  (pointer to tree data)
[5] inp_indices_p    (pointer to input indices)
[6] inp_values_p     (pointer to input values)
[7+] tree values     (n_nodes elements)
[...] input indices  (batch_size elements)
[...] input values   (batch_size elements)
[...] extra scratch space
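
The pointer arithmetic can be worked through for a hypothetical small instance (sizes here are made up for illustration):

```python
# A height-3 tree (15 nodes) and a batch of 4 inputs.
header = 7
n_nodes = 2 ** (3 + 1) - 1  # 15 tree values
batch_size = 4

forest_values_p = header                   # tree data starts after the header
inp_indices_p = forest_values_p + n_nodes  # indices follow the tree values
inp_values_p = inp_indices_p + batch_size  # values follow the indices

print(forest_values_p, inp_indices_p, inp_values_p)  # 7 22 26
```
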

Tracing for Debugging

reference_kernel2() populates a trace dictionary with intermediate values:
problem.py:551-562
trace[(h, i, "idx")] = idx
trace[(h, i, "val")] = val
trace[(h, i, "node_val")] = node_val
trace[(h, i, "hashed_val")] = val
trace[(h, i, "next_idx")] = idx
trace[(h, i, "wrapped_idx")] = idx
Your kernel uses debug instructions to compare against this trace:
perf_takehome.py:140
body.append(("debug", ("compare", tmp_idx, (round, i, "idx"))))
The simulator validates your results match the reference at every step.
Debugging strategy: When your kernel fails validation, the error message identifies the mismatching (round, i, stage) entry. Use the trace to understand what went wrong.
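
A sketch of that workflow (`first_mismatch` is not part of the harness; it just shows how to diff two trace dictionaries keyed by (round, i, stage)):

```python
# Scan the reference trace in order and report the first entry your
# kernel's trace disagrees with.
def first_mismatch(ref_trace, my_trace):
    for key in sorted(ref_trace):
        if key in my_trace and my_trace[key] != ref_trace[key]:
            return key, ref_trace[key], my_trace[key]
    return None

ref = {(0, 0, "idx"): 0, (0, 0, "hashed_val"): 123}
mine = {(0, 0, "idx"): 0, (0, 0, "hashed_val"): 999}  # buggy hash result
print(first_mismatch(ref, mine))  # ((0, 0, 'hashed_val'), 123, 999)
```
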

myhash() - The Hash Function

A 32-bit hash function with 6 stages:
problem.py:449-464
def myhash(a: int) -> int:
    """A simple 32-bit hash function"""
    fns = {
        "+": lambda x, y: x + y,
        "^": lambda x, y: x ^ y,
        "<<": lambda x, y: x << y,
        ">>": lambda x, y: x >> y,
    }

    def r(x):
        return x % (2**32)  # Keep values in 32-bit range

    for op1, val1, op2, op3, val3 in HASH_STAGES:
        a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))

    return a

HASH_STAGES Configuration

problem.py:439-446
HASH_STAGES = [
    ("+", 0x7ED55D16, "+", "<<", 12),
    ("^", 0xC761C23C, "^", ">>", 19),
    ("+", 0x165667B1, "+", "<<", 5),
    ("+", 0xD3A2646C, "^", "<<", 9),
    ("+", 0xFD7046C5, "+", "<<", 3),
    ("^", 0xB55A4F09, "^", ">>", 16),
]
Each stage is a tuple (op1, val1, op2, op3, val3) and computes:
tmp1 = op1(a, val1)      # e.g., a + 0x7ED55D16
tmp2 = op3(a, val3)      # e.g., a << 12
a = op2(tmp1, tmp2)      # e.g., tmp1 + tmp2
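
Worked through for the first stage ("+", 0x7ED55D16, "+", "<<", 12) with the arbitrary example input a = 1, reducing each intermediate mod 2**32 as r() does in myhash():

```python
M = 2**32
a = 1
tmp1 = (a + 0x7ED55D16) % M  # op1: a + val1
tmp2 = (a << 12) % M         # op3: a << val3
a = (tmp1 + tmp2) % M        # op2: tmp1 + tmp2
print(hex(a))  # 0x7ed56d17
```
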

Hash in KernelBuilder

The build_hash() method generates instructions for this:
perf_takehome.py:77-86
def build_hash(self, val_hash_addr, tmp1, tmp2, round, i):
    slots = []
    for hi, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
        slots.append(("alu", (op1, tmp1, val_hash_addr, self.scratch_const(val1))))
        slots.append(("alu", (op3, tmp2, val_hash_addr, self.scratch_const(val3))))
        slots.append(("alu", (op2, val_hash_addr, tmp1, tmp2)))
        slots.append(("debug", ("compare", val_hash_addr, (round, i, "hash_stage", hi))))
    return slots
Performance Impact:
  • 6 stages × 4 slots = 24 slots per hash
  • Each input hashes once per round: batch_size * rounds * 24 slots
  • For test case (256 inputs, 16 rounds): 98,304 hash slots!

Optimization Target

The hash function dominates runtime. VLIW packing here yields massive gains:
  • Baseline: 24 cycles per hash (one instruction per slot)
  • Optimized: ~8 cycles per hash (pack 3 operations per bundle)
  • 3x speedup just from hash optimization!
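
Why interleaving helps can be shown with a toy greedy list scheduler (a sketch, not the simulator's actual bundle format, and it ignores the debug slots): within a stage, tmp1 and tmp2 are independent, so two hash chains from different inputs can share the 3 ALU slots.

```python
# Two independent hash chains ("A" and "B"), 6 stages each. Per stage:
# t1 and t2 are independent, the combine c needs both, and the next
# stage needs c.
ops = []  # (name, dependencies), in program order
for chain in "AB":
    prev = None
    for s in range(6):
        deps = [prev] if prev else []
        ops.append((f"{chain}{s}t1", deps))
        ops.append((f"{chain}{s}t2", deps))
        ops.append((f"{chain}{s}c", [f"{chain}{s}t1", f"{chain}{s}t2"]))
        prev = f"{chain}{s}c"

done, cycles = set(), 0
while len(done) < len(ops):
    # Ops are ready once all their dependencies finished in earlier cycles.
    ready = [n for n, d in ops if n not in done and all(x in done for x in d)]
    done.update(ready[:3])  # issue up to 3 ready ops this cycle
    cycles += 1

print(cycles)  # 13 cycles for 2 hashes (~6.5 each) vs 36 fully serial
```

A single chain's critical path is 12 cycles (2 per stage), so the gain comes from filling the otherwise-idle third slot with a second input's chain.
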

Correctness Validation

The testing harness compares your kernel against the references:
perf_takehome.py:205-221
machine = Machine(mem, kb.instrs, kb.debug_info(), value_trace=value_trace)

for i, ref_mem in enumerate(reference_kernel2(mem, value_trace)):
    machine.run()  # Run until next pause
    
    # Check output values
    inp_values_p = ref_mem[6]
    assert (
        machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        == ref_mem[inp_values_p : inp_values_p + len(inp.values)]
    ), f"Incorrect result on round {i}"
Validation Points:
  1. After initial setup (first pause instruction)
  2. After complete computation (final pause)
  3. At every debug compare instruction (if enabled)
Critical: Your kernel must include pause instructions matching the reference’s yield statements, or validation will fail.
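
The lockstep protocol can be sketched with stand-ins (`ToyMachine` and `reference` below are made up; the real classes live in perf_takehome.py and problem.py):

```python
# Each yield in the reference corresponds to one pause in your kernel:
# the harness advances both in lockstep and compares memory states.
def reference(mem):
    yield mem      # state after setup
    mem[0] += 1    # "the computation"
    yield mem      # final state

class ToyMachine:
    def __init__(self, mem):
        self.mem, self._phase = mem, 0
    def run(self):  # run until the next pause instruction
        if self._phase == 1:
            self.mem[0] += 1  # computation happens between the two pauses
        self._phase += 1

machine = ToyMachine([0])
for ref_mem in reference([0]):
    machine.run()
    assert machine.mem == ref_mem  # states must match at every pause
```
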

Using the References

Step 1: Understand the Algorithm

Read reference_kernel() to grasp the high-level logic:
  • What operations happen per input?
  • What’s the data flow?
  • Where’s the parallelism?
Step 2: Study the Memory Layout

Examine build_mem_image() and reference_kernel2():
  • Where are tree values stored?
  • How are input arrays organized?
  • What’s the header structure?
Step 3: Trace Your Kernel

Run with tracing enabled to compare instruction-by-instruction:
python perf_takehome.py Tests.test_kernel_trace
Debug instructions validate against the trace at every step.
Step 4: Verify Correctness

After optimization, always validate:
python tests/submission_tests.py
This runs the frozen reference to ensure your results are correct.

Key Takeaways

Algorithm

Tree traversal with hash-based navigation, fully specified by reference_kernel()

Memory

Flat layout defined by build_mem_image(), accessed via pointers in header

Hash Function

6-stage computation defined by HASH_STAGES, major optimization target

Validation

Debug instructions compare against reference_kernel2() trace

Next Steps

Apply Optimization Strategies

Now that you understand the algorithm, learn how to make it fast.
