Overview

Fine-tuning allows you to customize pre-trained open-source models for your specific tasks and domain. Vertex AI provides managed fine-tuning services that handle infrastructure provisioning, distributed training, and hyperparameter optimization.

Fine-Tuning Methods

Vertex AI supports multiple fine-tuning approaches:

Full Fine-Tuning

Update all model parameters for maximum customization:
  • Best for: Domain-specific tasks requiring significant adaptation
  • Resource requirements: High (requires powerful GPUs)
  • Training time: Hours to days
  • Model quality: Highest potential quality

LoRA (Low-Rank Adaptation)

Update small low-rank adapter matrices while the base model weights stay frozen:
  • Best for: Most use cases; faster and cheaper than full fine-tuning
  • Resource requirements: Low to moderate
  • Training time: Shorter than full fine-tuning
  • Model quality: Close to full fine-tuning for many tasks

Supervised Fine-Tuning (SFT)

SFT trains the model on labeled prompt/response pairs so it learns to reproduce your preferred responses.

Preparing Your Dataset

Format your training data in JSONL (JSON Lines) format:
{"messages": [{"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}]}
{"messages": [{"role": "user", "content": "Explain photosynthesis"}, {"role": "assistant", "content": "Photosynthesis is the process by which plants convert light energy into chemical energy..."}]}

Dataset Requirements

  • Training set: At least 100 examples (1,000+ recommended)
  • Validation set: Less than 25% of training data and under 5,000 examples
  • Format: JSONL with consistent schema
  • Storage: Upload to Google Cloud Storage
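Before uploading, a quick sanity check can catch schema problems that would otherwise surface mid-training. This is a minimal sketch (the helper name and error format are our own) that verifies each line parses as JSON and follows the `messages` schema shown above:

```python
import json

def validate_jsonl(path):
    """Return a list of error strings for lines that are not valid JSON
    or do not contain a well-formed 'messages' list."""
    errors = []
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                errors.append(f"line {line_no}: not valid JSON")
                continue
            messages = record.get("messages")
            if not isinstance(messages, list) or not messages:
                errors.append(f"line {line_no}: missing 'messages' list")
                continue
            for msg in messages:
                if msg.get("role") not in {"system", "user", "assistant"} or "content" not in msg:
                    errors.append(f"line {line_no}: malformed message {msg!r}")
    return errors
```

Run it on both files and fix any reported lines before uploading; an empty list means the file passed.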

Upload Data to Cloud Storage

# Destination paths in Cloud Storage
BUCKET_URI = "gs://your-bucket-name"
train_file_uri = f"{BUCKET_URI}/datasets/train.jsonl"
validation_file_uri = f"{BUCKET_URI}/datasets/validation.jsonl"

# Using gcloud CLI
!gcloud storage cp train.jsonl {train_file_uri}
!gcloud storage cp validation.jsonl {validation_file_uri}

Full Fine-Tuning Example

Setup and Configuration

1. Install Dependencies

pip install --upgrade "google-cloud-aiplatform>=1.129.0"

2. Initialize Vertex AI

import vertexai
from vertexai.tuning import SourceModel, sft

PROJECT_ID = "your-project-id"
LOCATION = "us-central1"
BUCKET_URI = "gs://your-bucket"

vertexai.init(
    project=PROJECT_ID,
    location=LOCATION,
    staging_bucket=BUCKET_URI
)

3. Configure Fine-Tuning Job

from pydantic import BaseModel, Field

class TuningConfig(BaseModel):
    base_model: str = Field(
        default="meta/[email protected]",
        description="Base model to fine-tune"
    )
    tuning_mode: str = Field(
        default="FULL",
        description="FULL or LORA"
    )
    epochs: int = Field(
        default=3,
        description="Number of training epochs"
    )
    learning_rate: float = Field(
        default=2e-5,
        description="Learning rate"
    )

config = TuningConfig()

4. Launch Training

import uuid

output_uri = f"{BUCKET_URI}/tuning-output/{uuid.uuid4()}"

source_model = SourceModel(base_model=config.base_model)

sft_tuning_job = sft.train(
    source_model=source_model,
    tuning_mode=config.tuning_mode,
    epochs=config.epochs,
    learning_rate=config.learning_rate,
    train_dataset=train_file_uri,
    validation_dataset=validation_file_uri,
    output_uri=output_uri
)

5. Monitor Progress

import time
from google.cloud.aiplatform_v1beta1.types import JobState

print("Training started. Monitoring progress...")

while sft_tuning_job.state not in [
    JobState.JOB_STATE_CANCELLED,
    JobState.JOB_STATE_FAILED,
    JobState.JOB_STATE_SUCCEEDED,
]:
    time.sleep(600)  # Check every 10 minutes
    sft_tuning_job.refresh()
    print(f"Status: {sft_tuning_job.state.name}")

print(f"Training completed: {sft_tuning_job.state.name}")

LoRA Fine-Tuning

LoRA is more efficient for most use cases:
# Configure LoRA fine-tuning
sft_tuning_job = sft.train(
    source_model=SourceModel(base_model="meta/[email protected]"),
    tuning_mode="LORA",  # Use LoRA instead of FULL
    epochs=3,
    learning_rate=3e-4,  # Higher learning rate for LoRA
    train_dataset=train_file_uri,
    validation_dataset=validation_file_uri,
    output_uri=output_uri,
    # LoRA-specific parameters
    lora_rank=8,  # Rank of LoRA matrices
    lora_alpha=16,  # Scaling factor
    lora_dropout=0.05  # Dropout rate
)
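To see why LoRA is cheaper, note that for each adapted weight matrix it trains only two small factors instead of the full matrix. A quick back-of-the-envelope calculation (the 4096 x 4096 layer shape is a hypothetical example, using the rank of 8 configured above):

```python
def lora_param_count(d_in, d_out, rank):
    """LoRA adds two low-rank factors per adapted weight:
    A (d_in x rank) and B (rank x d_out); the base weight stays frozen."""
    return d_in * rank + rank * d_out

# Hypothetical 4096 x 4096 projection matrix at rank 8
full_params = 4096 * 4096                      # updated by full fine-tuning
lora_params = lora_param_count(4096, 4096, 8)  # updated by LoRA
print(f"full: {full_params:,}  lora: {lora_params:,}  "
      f"ratio: {full_params // lora_params}x")  # 256x fewer trainable parameters
```

This is why LoRA tolerates the higher learning rate shown above and fits on much smaller accelerators.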

Advanced Fine-Tuning with TRL

Use Hugging Face’s TRL (Transformer Reinforcement Learning) library for advanced techniques:

Using Custom Training Scripts

from google.cloud import aiplatform

# Define custom training job
# (script_path is required; "train.py" is a placeholder for your TRL training script)
job = aiplatform.CustomTrainingJob(
    display_name="gemma-trl-finetuning",
    script_path="train.py",
    container_uri="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-1.py310:latest",
    requirements=["trl", "transformers", "datasets", "peft", "bitsandbytes"]
)

# Run training
model = job.run(
    replica_count=1,
    machine_type="n1-standard-16",
    accelerator_type="NVIDIA_TESLA_V100",
    accelerator_count=2,
    args=[
        "--model_name", "google/gemma-2b",
        "--dataset", train_file_uri,
        "--output_dir", output_uri,
        "--num_epochs", "3",
        "--learning_rate", "2e-5"
    ]
)

Real-World Example: MetaMath Fine-Tuning

Fine-tune a model for mathematical reasoning:

1. Load MetaMathQA Dataset

from datasets import load_dataset

# Load dataset from Hugging Face
dataset = load_dataset("meta-math/MetaMathQA")["train"]

# Split into train/validation
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_split = split_dataset["train"]
validation_split = split_dataset["test"]

# Limit validation to 5000 examples
if len(validation_split) > 5000:
    validation_split = validation_split.shuffle(seed=42).select(range(5000))

2. Format Data

METAMATH_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""

def format_for_tuning(example):
    query = example["query"]
    response = example["response"]
    instruction = METAMATH_TEMPLATE.format(instruction=query)
    
    return {
        "messages": [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": f" {response}"}
        ]
    }

train_formatted = train_split.map(format_for_tuning, remove_columns=train_split.column_names)
val_formatted = validation_split.map(format_for_tuning, remove_columns=validation_split.column_names)

3. Save and Upload

import json

def save_to_jsonl(dataset, output_path):
    with open(output_path, "w") as f:
        for example in dataset:
            json.dump(example, f)
            f.write("\n")

save_to_jsonl(train_formatted, "metamath_train.jsonl")
save_to_jsonl(val_formatted, "metamath_val.jsonl")

# Upload to GCS
!gcloud storage cp metamath_train.jsonl {BUCKET_URI}/datasets/
!gcloud storage cp metamath_val.jsonl {BUCKET_URI}/datasets/

4. Fine-Tune Model

sft_tuning_job = sft.train(
    source_model=SourceModel(base_model="meta/[email protected]"),
    tuning_mode="FULL",
    epochs=3,
    learning_rate=2e-5,
    train_dataset=f"{BUCKET_URI}/datasets/metamath_train.jsonl",
    validation_dataset=f"{BUCKET_URI}/datasets/metamath_val.jsonl",
    output_uri=f"{BUCKET_URI}/metamath-output"
)

Deploy Fine-Tuned Models

Deploy from Training Output

from vertexai.preview import model_garden
import os

# Model artifacts location
model_artifacts_uri = os.path.join(
    output_uri,
    "postprocess/node-0/checkpoints/final"
)

# Create custom model
tuned_model = model_garden.CustomModel(gcs_uri=model_artifacts_uri)

# Deploy to endpoint
endpoint = tuned_model.deploy(
    machine_type="g2-standard-12",
    accelerator_type="NVIDIA_L4",
    accelerator_count=1,
    min_replica_count=1,
    max_replica_count=3
)

Test the Fine-Tuned Model

# Make prediction
prompt_template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response: Let's think step by step."

instruction = "James buys 5 packs of beef that are 4 pounds each. The price of beef is $5.50 per pound. How much did he pay?"

prediction = endpoint.predict(
    instances=[{
        "prompt": prompt_template.format(instruction=instruction),
        "max_tokens": 250,
        "temperature": 0.2,
        "top_p": 1.0,
        "top_k": 1
    }]
)

print(prediction.predictions[0])
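As a sanity check on this word problem, the reference answer is easy to compute directly, which is useful when spot-checking the model's step-by-step output:

```python
# 5 packs x 4 pounds each, at $5.50 per pound
expected_total = 5 * 4 * 5.50
print(expected_total)  # 110.0
```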

Serving Multiple LoRA Adapters

Serve multiple LoRA adapters with a single base model:
from huggingface_hub import snapshot_download

# Download LoRA adapters
sql_adapter = snapshot_download(
    repo_id="google-cloud-partnership/gemma-2-2b-it-lora-sql",
    local_dir="./adapters/sql"
)

code_adapter = snapshot_download(
    repo_id="google-cloud-partnership/gemma-2-2b-it-lora-magicoder",
    local_dir="./adapters/code"
)

# Upload to GCS
!gcloud storage cp -r ./adapters/* {BUCKET_URI}/lora-adapters/

# Deploy with vLLM supporting multiple adapters
# See serving documentation for detailed configuration
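As a rough sketch of the serving side (adapter names and local paths are the hypothetical ones downloaded above, and flags may vary by vLLM version), vLLM can register multiple LoRA adapters by name against one base model:

```shell
vllm serve google/gemma-2-2b-it \
  --enable-lora \
  --lora-modules sql=./adapters/sql code=./adapters/code
```

Requests to the OpenAI-compatible endpoint can then select an adapter by passing its registered name (e.g. `sql`) as the model.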

Hyperparameter Tuning

Optimize training hyperparameters:
# Full fine-tuning: 1e-5 to 5e-5
# LoRA: 1e-4 to 5e-4
learning_rates = [1e-5, 2e-5, 5e-5]

for lr in learning_rates:
    job = sft.train(
        source_model=source_model,
        learning_rate=lr,
        # ... other parameters
    )

Evaluation

Evaluate your fine-tuned model:
import pandas as pd

# Load test dataset
test_data = pd.read_json("test.jsonl", lines=True)

# Evaluate
results = []
for idx, row in test_data.iterrows():
    prediction = endpoint.predict(
        instances=[{"prompt": row["prompt"], "max_tokens": 200}]
    )
    results.append({
        "prompt": row["prompt"],
        "expected": row["completion"],
        "predicted": prediction.predictions[0]
    })

# Calculate metrics
results_df = pd.DataFrame(results)
results_df.to_csv("evaluation_results.csv", index=False)
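The loop above only stores raw outputs; a simple metric such as exact match can be computed from the same `results` list. This is a minimal sketch (the normalization is ours; real math evaluation usually extracts the final numeric answer first):

```python
def exact_match_rate(results):
    """Fraction of examples whose prediction equals the reference
    after trimming whitespace and lowercasing.
    results: list of dicts with 'expected' and 'predicted' keys."""
    if not results:
        return 0.0
    norm = lambda s: str(s).strip().lower()
    hits = [norm(r["expected"]) == norm(r["predicted"]) for r in results]
    return sum(hits) / len(hits)

# Illustrative sample: one match, one miss
sample = [
    {"expected": "Paris", "predicted": "paris"},
    {"expected": "110", "predicted": "100"},
]
print(f"Exact match: {exact_match_rate(sample):.2f}")  # Exact match: 0.50
```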

Best Practices

Data Quality

Use high-quality, diverse training data (1,000+ examples recommended)

Start with LoRA

Begin with LoRA fine-tuning before attempting full fine-tuning

Monitor Training

Use validation loss to detect overfitting and adjust epochs

Version Control

Track experiments with clear naming and metadata

Cost Management

Use spot VMs for training jobs to reduce costs

Evaluation First

Always evaluate before deploying to production

Cost Optimization

# Use spot instances for training
job = sft.train(
    source_model=source_model,
    # ... other parameters
    enable_spot_vm=True,  # Up to 80% cost savings
    spot_vm_retention_time=3600  # 1 hour retention
)

Next Steps

Deploy Models

Learn about optimized serving with vLLM and TGI

Example Notebooks

Explore fine-tuning examples on GitHub

Evaluation

Evaluate model quality with Vertex AI Evaluation

Model Garden

Browse pre-trained models to fine-tune