Overview
The data collection module provides functions for playing game episodes and collecting tokenized training data from multiple agents.
Functions
play_episode
def play_episode(
game: SnakeGame,
agent,
codec: EventCodec | None = None,
max_ticks: int = 200,
) -> tuple[dict[int, list[Event]], dict[int, SnakeState]]
Play a single episode with the given agent, collecting events and states at each tick.
game
SnakeGame
required
The game instance to play. Should be initialized with the desired width, height, and seed.
agent
object
required
Agent object with an act(state, legal_actions) method that returns an Action.
codec
EventCodec | None
default:"None"
Optional codec for encoding. Not used by play_episode itself, but passed for consistency.
max_ticks
int
default:"200"
Maximum number of ticks to run before terminating the episode.
Returns:
events_by_tick
dict[int, list[Event]]
Dictionary mapping tick numbers to lists of events that occurred in that tick.
states_by_tick
dict[int, SnakeState]
Dictionary mapping tick numbers to game states at those ticks. Always includes tick 0.
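The two returned dictionaries share tick keys, so they can be walked together in tick order. A minimal sketch using stand-in values (the real entries are lists of Event objects and SnakeState objects):

```python
# Stand-ins for play_episode's return values; real entries are
# Event lists and SnakeState objects.
events_by_tick = {0: [], 1: ["MOVE"], 2: ["MOVE", "EAT_FOOD"]}
states_by_tick = {0: "state0", 1: "state1", 2: "state2"}

timeline = []
for tick in sorted(states_by_tick):
    # A tick may have no events; .get keeps the loop robust.
    timeline.append((tick, len(events_by_tick.get(tick, []))))
# timeline pairs each tick with its event count, in order.
```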
Example:
from game_grammar.snake import SnakeGame
from game_grammar.data import play_episode
from game_grammar.agents import RandomAgent
# Create game and agent
game = SnakeGame(width=10, height=10, seed=42)
agent = RandomAgent()
# Play one episode
events_by_tick, states_by_tick = play_episode(
game=game,
agent=agent,
max_ticks=200
)
# Inspect results
print(f"Episode lasted {max(states_by_tick.keys())} ticks")
print(f"Final score: {states_by_tick[max(states_by_tick.keys())].score}")
The episode automatically terminates if the snake dies (state.alive is False) or reaches max_ticks.
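Because both termination causes leave the final state in states_by_tick, the cause can be recovered after the fact. A sketch with a stand-in state object (only the `alive` attribute described above is assumed):

```python
from types import SimpleNamespace

# Stand-in for play_episode's states dict; real values are SnakeState
# objects with (at least) an `alive` attribute.
states_by_tick = {
    0: SimpleNamespace(alive=True),
    1: SimpleNamespace(alive=True),
    2: SimpleNamespace(alive=False),  # snake died on tick 2
}
final_tick = max(states_by_tick)
reason = "death" if not states_by_tick[final_tick].alive else "max_ticks"
```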
collect_episodes
def collect_episodes(
n: int,
agent_mix: list[tuple[object, float]],
width: int = 10,
height: int = 10,
codec: EventCodec | None = None,
max_ticks: int = 200,
seed: int = 42,
) -> list[list[int]]
Collect multiple tokenized episodes from a weighted mixture of agents.
n
int
required
Number of episodes to collect.
agent_mix
list[tuple[object, float]]
required
List of (agent, weight) pairs for weighted random sampling. Weights don’t need to sum to 1.
width
int
default:"10"
Grid width for the game.
height
int
default:"10"
Grid height for the game.
codec
EventCodec | None
default:"None"
Codec for tokenization. If None, creates a default EventCodec() with standard settings.
max_ticks
int
default:"200"
Maximum ticks per episode.
seed
int
default:"42"
Random seed for reproducibility. Controls both agent selection and game initialization.
Returns:
list[list[int]]
List of tokenized episodes, where each episode is a list of token IDs.
Example:
from game_grammar.data import collect_episodes
from game_grammar.agents import RandomAgent, GreedyAgent, HybridAgent
from game_grammar.codec import EventCodec
from game_grammar.core import Salience
# Define agent mixture
agent_mix = [
(RandomAgent(), 0.3), # 30% random
(GreedyAgent(), 0.5), # 50% greedy
(HybridAgent(), 0.2), # 20% hybrid
]
# Create custom codec
codec = EventCodec(
snapshot_interval=16,
salience_threshold=Salience.TICK
)
# Collect 1000 episodes
episodes = collect_episodes(
n=1000,
agent_mix=agent_mix,
width=10,
height=10,
codec=codec,
max_ticks=200,
seed=42
)
print(f"Collected {len(episodes)} episodes")
print(f"Token counts: {[len(ep) for ep in episodes[:5]]}")
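Since episodes vary in length, a common next step is padding them into a rectangular batch before training. A sketch, where `pad_id = 0` is an assumption (the codec's actual padding token may differ):

```python
# Hypothetical token-ID episodes of uneven length.
episodes = [[5, 9, 2, 7], [5, 3], [5, 9, 1]]
pad_id = 0  # assumed padding token; check the codec's vocabulary

# Right-pad every episode to the length of the longest one.
max_len = max(len(ep) for ep in episodes)
batch = [ep + [pad_id] * (max_len - len(ep)) for ep in episodes]
```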
Agent mix configuration:
# Equal weighting
agent_mix = [
(RandomAgent(), 1.0),
(GreedyAgent(), 1.0),
]
# Heavy weighting toward better agents
agent_mix = [
(RandomAgent(), 0.1),
(GreedyAgent(), 0.9),
]
# Single agent (no mixing)
agent_mix = [(GreedyAgent(), 1.0)]
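Weighted sampling of this kind can be expressed with random.choices, which normalizes weights internally (which is why they need not sum to 1). A sketch of how selection could work; collect_episodes' internal logic may differ:

```python
import random

def sample_agent(agent_mix, rng):
    """Pick one agent from (agent, weight) pairs by weighted sampling."""
    agents = [agent for agent, _ in agent_mix]
    weights = [weight for _, weight in agent_mix]
    # random.choices normalizes weights, so any positive scale works.
    return rng.choices(agents, weights=weights, k=1)[0]

rng = random.Random(42)
agent_mix = [("random", 0.1), ("greedy", 0.9)]
picks = [sample_agent(agent_mix, rng) for _ in range(1000)]
# With weights 0.1/0.9, "greedy" dominates the picks.
```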
Each episode uses a fresh game instance with a new random seed derived from the main seed. This ensures variability across episodes.
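One way such per-episode seeds could be derived (the library's exact scheme is not specified here) is to draw them from a generator seeded with the main seed, which keeps the whole run reproducible:

```python
import random

def derive_episode_seeds(main_seed, n):
    """Derive n per-episode seeds deterministically from one main seed."""
    rng = random.Random(main_seed)
    return [rng.randrange(2**31) for _ in range(n)]

# Same main seed -> same per-episode seeds, so runs are reproducible,
# while each episode still gets a different game seed.
seeds_a = derive_episode_seeds(42, 5)
seeds_b = derive_episode_seeds(42, 5)
```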
Usage Patterns
Basic data collection
from game_grammar.data import collect_episodes
from game_grammar.agents import RandomAgent
# Collect training data from random play
episodes = collect_episodes(
n=500,
agent_mix=[(RandomAgent(), 1.0)],
seed=42
)
Multi-agent curriculum
# Stage 1: Learn from random agents
stage1 = collect_episodes(
n=1000,
agent_mix=[(RandomAgent(), 1.0)],
seed=100
)
# Stage 2: Mix random and greedy
stage2 = collect_episodes(
n=1000,
agent_mix=[
(RandomAgent(), 0.5),
(GreedyAgent(), 0.5),
],
seed=200
)
# Stage 3: Mostly expert play
stage3 = collect_episodes(
n=1000,
agent_mix=[
(RandomAgent(), 0.1),
(GreedyAgent(), 0.4),
(HybridAgent(), 0.5),
],
seed=300
)
Custom codec configuration
from game_grammar.codec import EventCodec
from game_grammar.core import Salience
# High-frequency snapshots for debugging
debug_codec = EventCodec(
snapshot_interval=4, # Snapshot every 4 ticks
salience_threshold=Salience.TICK
)
# Low-frequency snapshots for efficiency
efficient_codec = EventCodec(
snapshot_interval=32, # Snapshot every 32 ticks
salience_threshold=Salience.RULE_EFFECT # Only critical events
)
episodes = collect_episodes(
n=100,
agent_mix=[(GreedyAgent(), 1.0)],
codec=efficient_codec
)
Parallel collection
import multiprocessing as mp
from functools import partial
from game_grammar.data import collect_episodes
from game_grammar.agents import GreedyAgent
def collect_batch(batch_id, n_per_batch, agent_mix):
    return collect_episodes(
        n=n_per_batch,
        agent_mix=agent_mix,
        seed=batch_id * 1000  # Different seed per batch
    )
# The __main__ guard is required on platforms that spawn worker
# processes (Windows, macOS)
if __name__ == "__main__":
    agent_mix = [(GreedyAgent(), 1.0)]
    # Collect 10,000 episodes across 10 processes
    with mp.Pool(10) as pool:
        batches = pool.map(
            partial(collect_batch, n_per_batch=1000, agent_mix=agent_mix),
            range(10)
        )
    all_episodes = [ep for batch in batches for ep in batch]
Agent Requirements
Any agent used with play_episode or collect_episodes must implement:
class MyAgent:
def act(self, state: SnakeState, legal_actions: list[Action]) -> Action:
"""Choose an action given current state and legal moves."""
# Your logic here
return chosen_action
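For instance, a minimal conforming agent might look like this (a sketch only; the built-in agents in the Agents section are the real implementations):

```python
class FirstActionAgent:
    """Trivial agent satisfying the interface: always takes the
    first legal action offered."""

    def act(self, state, legal_actions):
        return legal_actions[0]

agent = FirstActionAgent()
choice = agent.act(None, ["UP", "LEFT"])
```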
See the Agents section for built-in agent implementations and examples.