
Overview

ChemLactica models can generate novel molecules based on conditional prompts that specify desired properties and structural similarities. The generation process uses the model’s understanding of SMILES notation and molecular properties learned during pretraining.

Generation Function

The core generation functionality is provided by the generate() function in chemlactica/generation/generation.py:
generation.py:9
def generate(prompts: List[str], model, **gen_kwargs):
    """
    Generate molecules from prompts.
    
    Args:
        prompts: List of formatted prompt strings or single prompt
        model: Loaded ChemLactica model
        **gen_kwargs: Generation parameters (see Sampling Strategies)
    
    Returns:
        Dictionary mapping prompts to list of generated sequences
    """

Basic Usage

Load Model and Generate

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from chemlactica.generation.generation import generate

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "yerevann/chemlactica-125m",
    torch_dtype=torch.bfloat16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(
    "yerevann/chemlactica-125m",
    padding_side="left"
)

# Create a prompt (see Prompting section)
prompt = "</s>[SAS]2.25[/SAS][SIMILAR]CC(=O)OC1=CC=CC=C1C(=O)O 0.62[/SIMILAR][START_SMILES]"

# Generate molecules
generation_dict = generate(
    prompts=[prompt],
    model=model,
    max_new_tokens=100,
    do_sample=True,
    temperature=1.0,
    repetition_penalty=1.0
)

Generation in Optimization

During molecular optimization, generation happens within the optimization loop in chemlactica/mol_opt/optimization.py:
optimization.py:102
# Create prompts from optimization entries
prompts = [
    optim_entry.to_prompt(
        is_generation=True,
        include_oracle_score=prev_train_iter != 0,
        config=config,
        max_score=max_score
    )
    for optim_entry in optim_entries
]

# Tokenize prompts
data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Truncate to prevent exceeding context length
for key, value in data.items():
    data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]

# Generate molecules
output = model.generate(
    **data,
    **config["generation_config"]
)

# Decode outputs
output_texts = tokenizer.batch_decode(output)

Extracting Generated SMILES

After generation, extract the SMILES string from the output:
optimization.py:39
def create_molecule_entry(output_text, validate_smiles):
    """Extract SMILES from generated text."""
    start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
    start_ind = output_text.rfind(start_smiles_tag)
    end_ind = output_text.rfind(end_smiles_tag)
    
    if start_ind == -1 or end_ind == -1:
        return None
    
    generated_smiles = output_text[start_ind+len(start_smiles_tag):end_ind]
    
    if not validate_smiles(generated_smiles):
        return None
    
    if len(generated_smiles) == 0:
        return None
    
    try:
        molecule = MoleculeEntry(smiles=generated_smiles)
        return molecule
    except Exception:
        return None
The model generates text between [START_SMILES] and [END_SMILES] tags. Always validate the generated SMILES using RDKit before using them.
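The tag-extraction step above can be exercised without loading a model. Here is a minimal standalone sketch of the same logic (the helper name extract_smiles is illustrative, not part of the library; it skips the RDKit validation that create_molecule_entry performs):

```python
def extract_smiles(output_text: str):
    """Return the text between the last [START_SMILES]/[END_SMILES] pair, or None."""
    start_tag, end_tag = "[START_SMILES]", "[END_SMILES]"
    start = output_text.rfind(start_tag)
    end = output_text.rfind(end_tag)
    # Reject missing tags, reversed tags, and empty payloads.
    if start == -1 or end == -1 or end <= start + len(start_tag):
        return None
    return output_text[start + len(start_tag):end]

print(extract_smiles("</s>[SAS]2.0[/SAS][START_SMILES]CCO[END_SMILES]"))  # CCO
print(extract_smiles("no tags here"))  # None
```

Using rfind (rather than find) matters because the prompt itself ends with [START_SMILES]; the last occurrence delimits the model's completion.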

Generation Parameters

Key parameters for molecule generation:
max_new_tokens (int, default: 100)
  Maximum number of new tokens to generate. SMILES strings typically require 50-100 tokens.

do_sample (bool, default: true)
  Whether to use sampling. Set to true for diverse molecule generation.

temperature (float, default: 1.0)
  Sampling temperature. Higher values (1.0-1.5) increase diversity. See Sampling Strategies.

repetition_penalty (float, default: 1.0)
  Penalty for repeating tokens. Usually kept at 1.0 for molecule generation.

eos_token_id (int, default: 20)
  Token ID that signals end of generation. ChemLactica uses token ID 20.
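To illustrate what the temperature parameter controls, here is a self-contained sketch of temperature-scaled softmax sampling over a toy logit vector (plain Python, not ChemLactica or transformers code):

```python
import math
import random

def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Sample a token index from logits after temperature scaling.

    temperature < 1 sharpens the distribution toward the argmax;
    temperature > 1 flattens it, increasing diversity.
    """
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    r = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r <= cumulative:
            return i
    return len(probs) - 1

# At very low temperature, sampling collapses to the argmax.
print(sample_with_temperature([1.0, 10.0, 2.0], temperature=0.01))  # 1
```

This is the same scaling model.generate applies to its logits when do_sample=True, which is why temperatures in the 1.0-1.5 range yield more diverse molecules.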

Batch Generation

For efficient generation, process multiple prompts in batches:
# Generate multiple molecules in parallel
prompts = [
    "</s>[SAS]2.0[/SAS][START_SMILES]",
    "</s>[SAS]3.0[/SAS][START_SMILES]",
    "</s>[QED]0.9[/QED][START_SMILES]"
]

generation_dict = generate(
    prompts=prompts,
    model=model,
    max_new_tokens=100,
    do_sample=True,
    temperature=1.0
)

for prompt, outputs in generation_dict.items():
    print(f"Prompt: {prompt}")
    for output in outputs:
        print(f"  Generated: {output}")
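Because generate() maps each prompt to a list of sequences, the same molecule can appear under several prompts when sampling. A small helper to collect unique outputs across the batch (the name unique_generations is illustrative, not part of the library):

```python
def unique_generations(generation_dict):
    """Collect unique generated strings across all prompts, preserving first-seen order."""
    seen = set()
    unique = []
    for outputs in generation_dict.values():
        for text in outputs:
            if text not in seen:
                seen.add(text)
                unique.append(text)
    return unique

# Toy stand-in for the dict returned by generate()
demo = {
    "promptA": ["CCO", "CCN", "CCO"],
    "promptB": ["CCN", "c1ccccc1"],
}
print(unique_generations(demo))  # ['CCO', 'CCN', 'c1ccccc1']
```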

Context Length Management

ChemLactica models have a maximum context length of 2048 tokens:
optimization.py:114
# Truncate input to fit within context window
for key, value in data.items():
    data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]
Always ensure your prompt length + max_new_tokens does not exceed 2048 tokens to avoid truncation issues.
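The slice above keeps only the most recent 2048 - max_new_tokens prompt tokens, so the prompt plus the generated tokens fit the window. The same budget logic in plain Python, with token IDs as a simple list (function name is illustrative):

```python
CONTEXT_WINDOW = 2048  # ChemLactica's maximum context length

def truncate_for_generation(input_ids, max_new_tokens, context_window=CONTEXT_WINDOW):
    """Keep the most recent tokens so prompt + generated tokens fit the window."""
    budget = context_window - max_new_tokens
    return input_ids[-budget:]

ids = list(range(2200))  # a prompt that is 2200 tokens long
kept = truncate_for_generation(ids, max_new_tokens=100)
print(len(kept))   # 1948
print(kept[-1])    # 2199
```

Truncating from the left (keeping the end of the prompt) matches the library's slice and preserves the trailing [START_SMILES] tag, which is what conditions the model to emit a molecule.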

Example: Complete Generation Pipeline

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from chemlactica.mol_opt.utils import MoleculeEntry
from rdkit import Chem

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "yerevann/chemlactica-125m",
    torch_dtype=torch.bfloat16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("yerevann/chemlactica-125m")

# Create prompt for aspirin-like molecule with SAS ~2.25
aspirin_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
prompt = f"</s>[SAS]2.25[/SAS][SIMILAR]{aspirin_smiles} 0.62[/SIMILAR][START_SMILES]"

# Generate
data = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(
    input_ids=data.input_ids,
    max_new_tokens=100,
    do_sample=True,
    temperature=1.0,
    eos_token_id=20
)

# Decode and extract SMILES
generated_text = tokenizer.decode(output[0])
start_idx = generated_text.rfind("[START_SMILES]") + len("[START_SMILES]")
end_idx = generated_text.rfind("[END_SMILES]")
smiles = generated_text[start_idx:end_idx]

# Validate with RDKit
mol = Chem.MolFromSmiles(smiles)
if mol is not None:
    print(f"Valid molecule generated: {smiles}")
    molecule_entry = MoleculeEntry(smiles=smiles)
else:
    print("Invalid SMILES generated")

Next Steps

Prompting

Learn how to format prompts with properties and similarity constraints

Sampling Strategies

Optimize generation quality with sampling parameters
