Overview
ChemLactica models can generate novel molecules based on conditional prompts that specify desired properties and structural similarities. The generation process uses the model’s understanding of SMILES notation and molecular properties learned during pretraining.
Generation Function
The core generation functionality is provided by the generate() function in chemlactica/generation/generation.py:
from chemlactica.generation.generation import generate

def generate(prompts: List[str], model, **gen_kwargs):
    """
    Generate molecules from prompts.

    Args:
        prompts: List of formatted prompt strings or a single prompt
        model: Loaded ChemLactica model
        **gen_kwargs: Generation parameters (see Sampling Strategies)

    Returns:
        Dictionary mapping each prompt to a list of generated sequences
    """
Basic Usage
Load Model and Generate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from chemlactica.generation.generation import generate

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "yerevann/chemlactica-125m",
    torch_dtype=torch.bfloat16,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(
    "yerevann/chemlactica-125m",
    padding_side="left",
)

# Create a prompt (see Prompting section)
prompt = "</s>[SAS]2.25[/SAS][SIMILAR]CC(=O)OC1=CC=CC=C1C(=O)O 0.62[/SIMILAR][START_SMILES]"

# Generate molecules
generation_dict = generate(
    prompts=[prompt],
    model=model,
    max_new_tokens=100,
    do_sample=True,
    temperature=1.0,
    repetition_penalty=1.0,
)
Generation in Optimization
During molecular optimization, generation happens within the optimization loop in chemlactica/mol_opt/optimization.py:
# Create prompts from optimization entries
prompts = [
    optim_entry.to_prompt(
        is_generation=True,
        include_oracle_score=prev_train_iter != 0,
        config=config,
        max_score=max_score,
    )
    for optim_entry in optim_entries
]

# Tokenize prompts
data = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Truncate to prevent exceeding context length
for key, value in data.items():
    data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]

# Generate molecules
output = model.generate(
    **data,
    **config["generation_config"],
)

# Decode outputs
output_texts = tokenizer.batch_decode(output)
After generation, extract the SMILES string from the output:
def create_molecule_entry(output_text, validate_smiles):
    """Extract SMILES from generated text."""
    start_smiles_tag, end_smiles_tag = "[START_SMILES]", "[END_SMILES]"
    start_ind = output_text.rfind(start_smiles_tag)
    end_ind = output_text.rfind(end_smiles_tag)
    if start_ind == -1 or end_ind == -1:
        return None
    generated_smiles = output_text[start_ind + len(start_smiles_tag):end_ind]
    if not validate_smiles(generated_smiles):
        return None
    if len(generated_smiles) == 0:
        return None
    try:
        return MoleculeEntry(smiles=generated_smiles)
    except Exception:
        return None
The model generates text between [START_SMILES] and [END_SMILES] tags. Always validate the generated SMILES using RDKit before using them.
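The tag-stripping step is plain string handling and can be exercised in isolation. A minimal sketch (`extract_smiles` is a hypothetical helper for illustration, not part of the ChemLactica API; RDKit validation would still follow it):

```python
from typing import Optional

def extract_smiles(output_text: str) -> Optional[str]:
    """Return the text between the last [START_SMILES]/[END_SMILES] pair, or None."""
    start_tag, end_tag = "[START_SMILES]", "[END_SMILES]"
    start = output_text.rfind(start_tag)
    end = output_text.rfind(end_tag)
    # Both tags must be present and correctly ordered
    if start == -1 or end == -1 or end < start:
        return None
    smiles = output_text[start + len(start_tag):end].strip()
    return smiles or None

print(extract_smiles("</s>[SAS]2.25[/SAS][START_SMILES]CC(=O)O[END_SMILES]"))  # CC(=O)O
print(extract_smiles("no tags here"))  # None
```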
Generation Parameters
Key parameters for molecule generation:

max_new_tokens: Maximum number of new tokens to generate. SMILES strings typically require 50-100 tokens.
do_sample: Whether to use sampling. Set to True for diverse molecule generation.
repetition_penalty: Penalty for repeating tokens. Usually kept at 1.0 for molecule generation.
eos_token_id: Token ID that signals the end of generation. ChemLactica uses token ID 20.
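Collected into a dictionary, these parameters can be passed straight through to model.generate(). A representative sketch (the specific values mirror the examples on this page; temperature appears in those examples as well, though it is not listed above):

```python
# A representative generation config; values match the examples on this page.
generation_config = {
    "max_new_tokens": 100,      # SMILES strings typically fit in 50-100 tokens
    "do_sample": True,          # sample for diverse molecule generation
    "temperature": 1.0,
    "repetition_penalty": 1.0,  # usually left at 1.0 for SMILES
    "eos_token_id": 20,         # ChemLactica's end-of-generation token
}

# Usage: output = model.generate(**inputs, **generation_config)
print(sorted(generation_config))
```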
Batch Generation
For efficient generation, process multiple prompts in batches:
# Generate multiple molecules in parallel
prompts = [
    "</s>[SAS]2.0[/SAS][START_SMILES]",
    "</s>[SAS]3.0[/SAS][START_SMILES]",
    "</s>[QED]0.9[/QED][START_SMILES]",
]
generation_dict = generate(
    prompts=prompts,
    model=model,
    max_new_tokens=100,
    do_sample=True,
    temperature=1.0,
)
for prompt, outputs in generation_dict.items():
    print(f"Prompt: {prompt}")
    for output in outputs:
        print(f"  Generated: {output}")
Context Length Management
ChemLactica models have a maximum context length of 2048 tokens:
# Truncate input to fit within the context window
for key, value in data.items():
    data[key] = value[:, -2048 + config["generation_config"]["max_new_tokens"]:]
Always ensure that the prompt length plus max_new_tokens does not exceed 2048 tokens; otherwise the oldest prompt tokens are truncated from the left and conditioning information may be lost.
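The left-truncation arithmetic can be checked without tensors. A minimal sketch with a plain list standing in for the token-ID tensor (`truncate_left` is an illustrative helper, not part of the codebase; slicing with `[-(2048 - max_new_tokens):]` is equivalent to the `[-2048 + max_new_tokens:]` form above):

```python
CONTEXT_LENGTH = 2048

def truncate_left(token_ids, max_new_tokens, context_length=CONTEXT_LENGTH):
    """Keep only the most recent tokens so prompt + generation fits the context."""
    budget = context_length - max_new_tokens  # room left for the prompt
    return token_ids[-budget:]

long_prompt = list(range(3000))  # an over-long prompt of 3000 token IDs
kept = truncate_left(long_prompt, max_new_tokens=100)
print(len(kept))  # 1948 == 2048 - 100
print(kept[0])    # 1052: the oldest 1052 tokens were dropped
```

Prompts already shorter than the budget pass through unchanged, since a negative slice start larger than the list simply keeps everything.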
Example: Complete Generation Pipeline
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from chemlactica.mol_opt.utils import MoleculeEntry
from rdkit import Chem

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "yerevann/chemlactica-125m",
    torch_dtype=torch.bfloat16,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("yerevann/chemlactica-125m")

# Create prompt for an aspirin-like molecule with SAS ~2.25
aspirin_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
prompt = f"</s>[SAS]2.25[/SAS][SIMILAR]{aspirin_smiles} 0.62[/SIMILAR][START_SMILES]"

# Generate
data = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(
    input_ids=data.input_ids,
    max_new_tokens=100,
    do_sample=True,
    temperature=1.0,
    eos_token_id=20,
)

# Decode and extract SMILES
generated_text = tokenizer.decode(output[0])
start_idx = generated_text.rfind("[START_SMILES]") + len("[START_SMILES]")
end_idx = generated_text.rfind("[END_SMILES]")
smiles = generated_text[start_idx:end_idx]

# Validate with RDKit
mol = Chem.MolFromSmiles(smiles)
if mol is not None:
    print(f"Valid molecule generated: {smiles}")
    molecule_entry = MoleculeEntry(smiles=smiles)
else:
    print("Invalid SMILES generated")
Next Steps
Prompting: Learn how to format prompts with properties and similarity constraints.
Sampling Strategies: Optimize generation quality with sampling parameters.