The generate() function provides a simple interface for generating molecules from text prompts using ChemLactica models.

Function Signature

from chemlactica.generation.generation import generate

def generate(
    prompts: List[str],
    model,
    **gen_kwargs
) -> Dict[str, List[str]]
Source: generation/generation.py:9

Parameters

prompts
List[str] | str
required
Input prompt(s) for molecule generation. Can be a single string or a list of strings. Each prompt should follow the ChemLactica format, with property tags followed by the [START_SMILES] token.
model
PreTrainedModel
required
The loaded ChemLactica model (e.g., from AutoModelForCausalLM.from_pretrained()).
gen_kwargs
dict
Generation parameters passed to model.generate(). Common parameters:
  • max_new_tokens (int): Maximum tokens to generate
  • temperature (float): Sampling temperature; values below 1.0 are more conservative, values above 1.0 more diverse
  • do_sample (bool): Whether to use sampling vs greedy decoding
  • num_return_sequences (int): Number of sequences to generate per prompt
  • repetition_penalty (float): Penalty for repeating tokens
  • top_p (float): Nucleus sampling threshold
  • top_k (int): Number of highest-probability tokens to sample from
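For illustration, prompts that combine several property tags can be assembled programmatically. The `build_prompt` helper below is hypothetical, not part of the chemlactica API; the tag format follows the examples on this page.

```python
# Hypothetical helper (not part of chemlactica) that assembles a prompt
# in the [TAG]value[/TAG] format used throughout this page.
def build_prompt(**properties):
    """Join [TAG]value[/TAG] pairs and append the [START_SMILES] token."""
    tags = "".join(f"[{name}]{value}[/{name}]" for name, value in properties.items())
    return f"</s>{tags}[START_SMILES]"

prompt = build_prompt(SAS=2.5, QED=0.8)
# → "</s>[SAS]2.5[/SAS][QED]0.8[/QED][START_SMILES]"
```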

Returns

generation_dict
Dict[str, List[str]]
Dictionary mapping each prompt to a list of generated completions. Each completion includes the full generated text from the model.

Basic Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from chemlactica.generation.generation import generate

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "yerevann/chemlactica-125m",
    torch_dtype=torch.bfloat16
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(
    "yerevann/chemlactica-125m"
)

# Single prompt
prompt = "</s>[SAS]2.5[/SAS][START_SMILES]"
results = generate(
    prompts=prompt,
    model=model,
    max_new_tokens=128,
    temperature=1.0,
    do_sample=True
)

print(results[prompt][0])  # First generation

Batch Generation

# Multiple prompts
prompts = [
    "</s>[SAS]2.5[/SAS][QED]0.8[/QED][START_SMILES]",
    "</s>[MW]300[/MW][TPSA]50[/TPSA][START_SMILES]",
    "</s>[SIMILAR]CC(=O)OC1=CC=CC=C1C(=O)O 0.7[/SIMILAR][START_SMILES]"
]

results = generate(
    prompts=prompts,
    model=model,
    max_new_tokens=128,
    temperature=1.0,
    do_sample=True,
    num_return_sequences=5  # Generate 5 molecules per prompt
)

# Access results
for prompt, generations in results.items():
    print(f"Prompt: {prompt}")
    for i, gen in enumerate(generations):
        print(f"  Generation {i+1}: {gen}")

Generation Parameters

Temperature

# Low temperature (more focused)
results = generate(
    prompts=prompt,
    model=model,
    temperature=0.5,  # More conservative
    do_sample=True
)

# High temperature (more diverse)
results = generate(
    prompts=prompt,
    model=model,
    temperature=1.5,  # More exploratory
    do_sample=True
)

Multiple Sequences

# Generate multiple molecules
results = generate(
    prompts=prompt,
    model=model,
    max_new_tokens=128,
    num_return_sequences=10,  # 10 molecules
    do_sample=True,
    temperature=1.0
)

print(f"Generated {len(results[prompt])} molecules")
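Sampling with a high num_return_sequences often yields duplicate molecules. A minimal, order-preserving deduplication sketch (the completions list here is stand-in data, not real model output):

```python
# Stand-in completions; in practice use results[prompt] from generate().
generations = ["CCO", "c1ccccc1", "CCO", "CC(=O)O", "c1ccccc1"]

# dict.fromkeys keeps the first occurrence of each string, in order.
unique = list(dict.fromkeys(generations))
# → ["CCO", "c1ccccc1", "CC(=O)O"]
```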

Extracting SMILES

The generated text includes the full completion. Extract the SMILES string with a regular expression:

import re

def extract_smiles(generated_text):
    """Extract SMILES from generated text."""
    match = re.search(
        r'\[START_SMILES\](.*?)\[END_SMILES\]',
        generated_text
    )
    if match:
        return match.group(1)
    return None

# Extract SMILES from generations
for prompt, generations in results.items():
    for gen in generations:
        smiles = extract_smiles(gen)
        if smiles:
            print(f"SMILES: {smiles}")

Notes

The generate() function is a simple wrapper around Hugging Face’s model.generate(). For production use in molecular optimization, consider using the optimize() function which includes additional features like pool management and oracle scoring.
Make sure the tokenizer is configured with padding_side="left" for batch generation to work correctly.
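The left-padding setup can look like the following (standard Hugging Face tokenizer attributes; the pad-token fallback is a common convention, not something this page specifies):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "yerevann/chemlactica-125m",
    padding_side="left",  # pad on the left so generation continues from the prompt's end
)

# Common fallback when a checkpoint defines no pad token:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```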

Command-Line Usage

The generation module also includes a command-line interface:
python -m chemlactica.generation.generation \
  --prompts "</s>[SAS]2.5[/SAS][START_SMILES]" \
  --checkpoint_path yerevann/chemlactica-125m \
  --device cuda \
  --max_new_tokens 128 \
  --do_sample \
  --temperature 1.0
