Environment
Base abstract class for creating RL environments to train and evaluate LLMs.
Overview
The Environment class provides the core infrastructure for:
- Managing datasets (training and evaluation)
- Running rollouts with LLM clients
- Scoring rollouts with rubrics
- Handling state lifecycle and cleanup
- Token usage tracking
All custom environments must inherit from this class and implement the rollout() method.
Inheritance Hierarchy
Environment (abstract)
├── SingleTurnEnv
├── MultiTurnEnv
│ ├── ToolEnv
│ │ └── StatefulToolEnv
│ └── [Custom MultiTurn Environments]
└── EnvGroup
Constructor
Environment(
dataset: Dataset | DatasetBuilder | None = None,
eval_dataset: Dataset | DatasetBuilder | None = None,
system_prompt: str | None = None,
few_shot: Messages | None = None,
parser: Parser | None = None,
rubric: Rubric | None = None,
sampling_args: SamplingArgs | None = None,
message_type: MessageType | object = _MESSAGE_TYPE_UNSET,
tool_defs: list[Tool] | None = None,
max_workers: int = 512,
env_id: str | None = None,
env_args: dict | None = None,
map_kwargs: dict = {},
max_seq_len: int | None = None,
score_rollouts: bool = True,
pass_threshold: float = 0.5,
**kwargs
)
Parameters
dataset
Dataset | DatasetBuilder | None
Training dataset or a callable that returns a dataset. Either dataset or eval_dataset must be provided.
eval_dataset
Dataset | DatasetBuilder | None
Evaluation dataset or a callable that returns a dataset.
system_prompt
str | None
System prompt to prepend to all conversations.
few_shot
Messages | None
Few-shot examples to include in prompts.
parser
Parser | None
Parser for extracting structured data from completions. Defaults to Parser().
rubric
Rubric | None
Rubric for scoring rollouts. Defaults to Rubric().
sampling_args
SamplingArgs | None
Default sampling arguments for generation (temperature, top_p, etc.).
tool_defs
list[Tool] | None
Provider-agnostic tool definitions in vf.Tool format.
max_workers
int
Maximum number of worker threads for synchronous execution.
env_id
str | None
Unique identifier for this environment.
env_args
dict | None
Additional environment-specific arguments.
map_kwargs
dict
Keyword arguments to pass to HuggingFace dataset .map() operations.
max_seq_len
int | None
Maximum sequence length for tokenization and truncation.
score_rollouts
bool
Whether to score rollouts using the rubric.
pass_threshold
float
Reward threshold for considering a rollout as “passed”.
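Since either dataset or eval_dataset may be a DatasetBuilder (a zero-argument callable), construction can stay lazy until the data is first needed. A minimal sketch of that builder-or-dataset pattern, assuming nothing about the library's internals (resolve_dataset and toy_builder are hypothetical names for illustration):

```python
# Hedged sketch (not the library's actual resolution code): dataset and
# eval_dataset accept either a materialized dataset or a zero-argument
# callable (a DatasetBuilder) that produces one.
def resolve_dataset(source):
    # If the source is callable, invoke it to build the dataset;
    # otherwise treat it as an already-built dataset.
    return source() if callable(source) else source

def toy_builder():
    # Hypothetical builder returning rows with prompt/answer fields.
    return [{"prompt": "2 + 2 = ?", "answer": "4"}]
```

Deferring construction this way avoids loading large datasets when an environment is instantiated only for evaluation, or vice versa.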
Core Methods
rollout
async def rollout(
input: RolloutInput,
client: Client,
model: str,
sampling_args: SamplingArgs | None = None
) -> State
Run a single rollout for a given input. Must be implemented by subclasses.
input
RolloutInput
Input data from the dataset containing prompt, answer, etc.
client
Client
LLM client for making API calls.
model
str
Model identifier (e.g., "gpt-4", "claude-3-5-sonnet").
sampling_args
SamplingArgs | None
Optional sampling arguments to override defaults.
Returns: State - Final state after rollout completion.
get_model_response
async def get_model_response(
state: State,
prompt: Messages | str,
client: Client | None = None,
model: str | None = None,
tool_defs: list[Tool] | None = None,
sampling_args: SamplingArgs | None = None
) -> Response
Get model response for a given prompt (chat or completion).
prompt
Messages | str
Prompt as messages or string.
client
Client | None
Client to use (defaults to state["client"]).
model
str | None
Model to use (defaults to state["model"]).
tool_defs
list[Tool] | None
Tools available for this request (defaults to state["tool_defs"]).
sampling_args
SamplingArgs | None
Sampling arguments (defaults to state["sampling_args"]).
Returns: Response - Model response with message, usage, etc.
init_state
async def init_state(
input: RolloutInput,
client: Client | ClientConfig,
model: str,
sampling_args: SamplingArgs | None = None
) -> State
Create initial state from dataset input. Called automatically at the start of each rollout.
input
RolloutInput
Input data from the dataset.
client
Client | ClientConfig
Client or client configuration.
Returns: State - Initialized state with input fields, client, model, etc.
Dataset Methods
build_dataset
def build_dataset() -> Dataset | None
Build and cache the training dataset from source if needed.
Returns: Dataset | None - Built dataset or None if no source.
build_eval_dataset
def build_eval_dataset() -> Dataset | None
Build and cache the evaluation dataset from source if needed.
Returns: Dataset | None - Built dataset or None if no source.
get_dataset
def get_dataset(n: int = -1, seed: int | None = None) -> Dataset
Get the training dataset, optionally shuffled and limited.
n
int
Maximum number of examples to return. -1 returns all.
seed
int | None
Random seed for shuffling.
Returns: Dataset - Training dataset.
get_eval_dataset
def get_eval_dataset(n: int = -1, seed: int | None = None) -> Dataset
Get the evaluation dataset, optionally shuffled and limited. Falls back to training dataset if no eval dataset exists.
n
int
Maximum number of examples to return. -1 returns all.
seed
int | None
Random seed for shuffling.
Returns: Dataset - Evaluation dataset.
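The n/seed semantics described above (optional deterministic shuffle, then limit, with -1 meaning "return everything") can be sketched in plain Python; this is an illustration of the documented behavior, not the library's implementation, and take_examples is a hypothetical name:

```python
import random

# Hedged sketch of the n/seed semantics: shuffle when a seed is given,
# then limit to n examples, with n = -1 meaning "return everything".
def take_examples(rows, n=-1, seed=None):
    if seed is not None:
        rows = rows[:]                      # avoid mutating the caller's list
        random.Random(seed).shuffle(rows)   # deterministic shuffle per seed
    return rows if n == -1 else rows[:n]
```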
Generation & Evaluation
generate
async def generate(
inputs: Dataset | List[RolloutInput],
client: Client | ClientConfig,
model: str,
sampling_args: SamplingArgs | None = None,
max_concurrent: int = -1,
results_path: Path | None = None,
state_columns: list[str] | None = None,
save_results: bool = False,
push_to_hf_hub: bool = False,
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | list[ProgressCallback] | None = None,
on_log: LogCallback | None = None
) -> GenerateOutputs
Generate rollouts for a set of inputs.
inputs
Dataset | List[RolloutInput]
Input examples to generate rollouts for.
client
Client | ClientConfig
LLM client or client configuration.
sampling_args
SamplingArgs | None
Sampling arguments to override defaults.
max_concurrent
int
Maximum concurrent rollouts. -1 for unlimited.
results_path
Path | None
Path to save/resume results.
state_columns
list[str] | None
Additional state fields to include in outputs.
save_results
bool
Whether to save results to disk.
push_to_hf_hub
bool
Whether to push results to HuggingFace Hub.
hf_hub_dataset_name
str | None
Dataset name for HuggingFace Hub.
independent_scoring
bool
Score rollouts independently vs. in groups.
max_retries
int
Maximum retries for failed rollouts.
on_start
StartCallback | None
Callback when generation starts.
on_progress
ProgressCallback | list[ProgressCallback] | None
Progress callback(s). None uses the default tqdm progress bar.
Returns: GenerateOutputs - Dictionary with outputs and metadata keys.
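The max_concurrent behavior (bound the number of in-flight rollouts, with -1 meaning unlimited) is typically implemented with an asyncio semaphore. A minimal sketch of that pattern, assuming nothing about how generate() actually schedules work (run_limited is a hypothetical name):

```python
import asyncio

# Hedged sketch of the max_concurrent semantics (the real generate() is
# more involved): gate coroutines behind a semaphore, where -1 means
# unlimited concurrency.
async def run_limited(coro_fns, max_concurrent=-1):
    if max_concurrent == -1:
        return await asyncio.gather(*(fn() for fn in coro_fns))
    sem = asyncio.Semaphore(max_concurrent)

    async def guarded(fn):
        async with sem:  # at most max_concurrent tasks in flight
            return await fn()

    return await asyncio.gather(*(guarded(fn) for fn in coro_fns))
```

asyncio.gather preserves input order, so results line up with their inputs regardless of completion order.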
generate_sync
def generate_sync(
inputs: Dataset | List[RolloutInput],
client: Client | ClientConfig,
**kwargs
) -> GenerateOutputs
Synchronous wrapper for generate(). Handles event loop creation.
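"Handles event loop creation" suggests the common sync-wrapper pattern: run the coroutine on a fresh event loop, and fail loudly if one is already running (nesting asyncio.run() inside a running loop is an error). This is an assumption about the pattern, not the library's code, and run_coro_sync is a hypothetical name:

```python
import asyncio

# Hedged sketch of a sync wrapper around an async API.
def run_coro_sync(coro):
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no loop running: create one and run
    coro.close()  # suppress the "coroutine was never awaited" warning
    raise RuntimeError("call the async variant from inside an event loop")
```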
evaluate
async def evaluate(
client: Client | ClientConfig,
model: str,
sampling_args: SamplingArgs | None = None,
num_examples: int = -1,
rollouts_per_example: int = 1,
max_concurrent: int = -1,
results_path: Path | None = None,
state_columns: list[str] | None = None,
save_results: bool = False,
push_to_hf_hub: bool = False,
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | list[ProgressCallback] | None = None,
on_log: LogCallback | None = None,
**kwargs
) -> GenerateOutputs
Evaluate model on the environment’s evaluation dataset.
client
Client | ClientConfig
LLM client or client configuration.
num_examples
int
Number of examples to evaluate. -1 for all.
rollouts_per_example
int
Number of rollouts to generate per example.
Other parameters are the same as generate().
Returns: GenerateOutputs - Dictionary with outputs and metadata keys.
evaluate_sync
def evaluate_sync(
client: Client | ClientConfig,
model: str,
sampling_args: SamplingArgs | None = None,
num_examples: int = -1,
rollouts_per_example: int = 1,
max_concurrent: int = -1,
results_path: Path | None = None,
state_columns: list[str] | None = None,
save_results: bool = False,
push_to_hf_hub: bool = False,
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0
) -> GenerateOutputs
Synchronous wrapper for evaluate().
Token Usage Tracking
get_state_usage
def get_state_usage(state: State) -> TokenUsage | None
Get token usage statistics for a state.
Returns: TokenUsage | None - Dictionary with input_tokens and output_tokens keys, or None.
increment_state_usage
def increment_state_usage(
state: State,
input_tokens: int | float = 0,
output_tokens: int | float = 0
) -> None
Manually increment token usage for a state.
increment_state_usage_from_response
def increment_state_usage_from_response(
state: State,
response: object
) -> None
Extract and increment token usage from a response object.
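The bookkeeping behind these three methods can be sketched in a few lines. This is an illustration of the described behavior (field names taken from the TokenUsage description above), not the library's implementation, and the "usage" state key is an assumption:

```python
# Hedged sketch of per-state token accounting.
def increment_usage(state, input_tokens=0, output_tokens=0):
    usage = state.setdefault("usage", {"input_tokens": 0, "output_tokens": 0})
    usage["input_tokens"] += input_tokens
    usage["output_tokens"] += output_tokens

def get_usage(state):
    # Return the usage dict, or None if nothing has been recorded yet.
    return state.get("usage")
```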
State Lifecycle
is_completed
async def is_completed(state: State, **kwargs) -> bool
Check all stop conditions. Sets state["is_completed"] = True if any condition is met.
Returns: bool - True if any stop condition is met.
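The check-all-conditions-and-latch behavior described above can be sketched as follows; the actual method is async and consults environment-specific conditions, so this is only an illustration (check_completed and stop_conditions are hypothetical names):

```python
# Hedged sketch: mark the state completed as soon as any condition fires,
# and keep it completed thereafter.
def check_completed(state, stop_conditions):
    if any(cond(state) for cond in stop_conditions):
        state["is_completed"] = True
    return state.get("is_completed", False)
```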
Configuration
set_kwargs
def set_kwargs(**kwargs) -> None
Set environment attributes using setter methods when available.
add_rubric
def add_rubric(rubric: Rubric) -> None
Add a rubric to the environment. Creates a RubricGroup if a rubric already exists.
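The merging rule described above (a second rubric promotes the existing one into a group) might look like the following sketch; RubricGroup internals are assumed, so RubricGroupSketch is a hypothetical stand-in:

```python
# Hedged sketch of add_rubric's merging behavior.
class RubricGroupSketch:
    def __init__(self, rubrics):
        self.rubrics = list(rubrics)

def add_rubric(env, rubric):
    current = env.get("rubric")
    if current is None:
        env["rubric"] = rubric                              # first rubric
    elif isinstance(current, RubricGroupSketch):
        current.rubrics.append(rubric)                      # extend the group
    else:
        env["rubric"] = RubricGroupSketch([current, rubric])  # promote to group
```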
set_max_seq_len
def set_max_seq_len(max_seq_len: int | None) -> None
Set the maximum sequence length.
set_score_rollouts
def set_score_rollouts(score_rollouts: bool) -> None
Set whether to score rollouts.
Server Methods
start_server
async def start_server(
address: str | None = None,
extra_env_kwargs: dict[str, Any] | None = None,
log_level: str | None = None,
log_file: str | None = None,
log_file_level: str | None = None,
health_check_interval: float = 1.0,
startup_timeout: float = 600.0,
recovery_timeout: float = 600.0
) -> None
This method is subject to change. External users should avoid depending on it directly.
Start a ZMQ server process for distributed rollout execution.
stop_server
async def stop_server() -> None
This method is subject to change. External users should avoid depending on it directly.
Stop the ZMQ server process.
Static Methods
make_dataset
@staticmethod
def make_dataset(...) -> Dataset
Utility for creating HuggingFace datasets. See verifiers.utils.save_utils.make_dataset for details.
Example Usage
import verifiers as vf
from datasets import load_dataset

# Create a simple environment
class MyEnv(vf.Environment):
    async def rollout(
        self,
        input: vf.RolloutInput,
        client: vf.Client,
        model: str,
        sampling_args: vf.SamplingArgs | None = None,
    ) -> vf.State:
        state = await self.init_state(input, client, model, sampling_args)

        # Get model response
        response = await self.get_model_response(
            state,
            prompt=state["prompt"],
        )

        # Store completion
        state["completion"] = response.message
        state["is_completed"] = True
        return state

# Load environment with dataset
def load_environment():
    dataset = load_dataset("gsm8k", "main", split="train")

    def reward_fn(answer: str, completion: vf.Messages) -> float:
        # Custom reward logic
        return 1.0 if answer in str(completion) else 0.0

    return MyEnv(
        dataset=dataset,
        rubric=vf.Rubric(reward_fn),
        system_prompt="You are a helpful assistant.",
    )

# Evaluate (evaluate_sync wraps the async evaluate() for use outside an event loop)
env = load_environment()
results = env.evaluate_sync(
    client=vf.ClientConfig(
        provider="openai",
        api_key="sk-...",
    ),
    model="gpt-4",
    num_examples=10,
)
print(f"Average reward: {results['metadata']['avg_reward']}")
See Also