The Flyte KFPyTorch plugin dispatches @task(task_config=PyTorch(...)) tasks to the Kubeflow training-operator, which manages PyTorchJob Kubernetes resources for distributed training.
Prerequisites
A running Kubernetes cluster with Flyte installed
helm and kubectl configured
GPU nodes (optional but typical for training workloads)
Step 1: Install the Kubeflow training-operator
kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone?ref=v1.7.0"
Verify the training-operator is running:
kubectl get pods -n kubeflow
# NAME READY STATUS RESTARTS AGE
# training-operator-xxxxxxxxx-xxxxx 1/1 Running 0 60s
The training-operator also supports TensorFlow (TFJob), MPI (MPIJob), XGBoost (XGBoostJob), and Paddle (PaddleJob). Installing it once enables all these job types.
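Conceptually, the plugin translates a task's PyTorch config into a PyTorchJob manifest that the training-operator reconciles. A rough sketch of that mapping, using the PyTorchJob CRD's field names (the helper function and image value are illustrative, not the plugin's actual code):

```python
# Illustrative sketch: roughly how a task_config with N worker replicas
# maps onto a PyTorchJob manifest. Field names follow the Kubeflow
# PyTorchJob CRD; this is NOT the plugin's real implementation.
def pytorchjob_manifest(name: str, image: str, worker_replicas: int) -> dict:
    pod_template = {
        "spec": {
            "containers": [{"name": "pytorch", "image": image}],
        }
    }
    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "PyTorchJob",
        "metadata": {"name": name},
        "spec": {
            "pytorchReplicaSpecs": {
                "Master": {"replicas": 1, "template": pod_template},
                "Worker": {"replicas": worker_replicas, "template": pod_template},
            }
        },
    }

manifest = pytorchjob_manifest("demo", "my-image:latest", 4)
print(manifest["spec"]["pytorchReplicaSpecs"]["Worker"]["replicas"])  # 4
```

The operator then creates one master pod plus the requested worker pods and wires up the rendezvous environment between them.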
Step 2: Enable the PyTorch plugin in Flyte
Create a values-pytorch.yaml override file. For the flyte-binary chart:

configuration:
  inline:
    tasks:
      task-plugins:
        enabled-plugins:
          - container
          - sidecar
          - k8s-array
          - pytorch
        default-for-task-types:
          container: container
          container_array: k8s-array
          pytorch: pytorch

For the flyte-core chart:

configmap:
  enabled_plugins:
    tasks:
      task-plugins:
        enabled-plugins:
          - container
          - sidecar
          - k8s-array
          - pytorch
        default-for-task-types:
          container: container
          sidecar: sidecar
          container_array: k8s-array
          pytorch: pytorch
Apply the override:
helm upgrade flyte-backend flyteorg/flyte-binary \
--namespace flyte \
--values values.yaml \
--values values-pytorch.yaml
Step 3: Write a distributed PyTorch task
Install the flytekit PyTorch plugin:
pip install flytekitplugins-kfpytorch
Single-worker PyTorch task
from flytekit import task, workflow
from flytekitplugins.kfpytorch import PyTorch, Worker

@task(
    task_config=PyTorch(
        worker=Worker(replicas=1),
    )
)
def single_node_training() -> float:
    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    return x.mean().item()

@workflow
def single_node_wf() -> float:
    return single_node_training()
Multi-worker distributed training
from flytekit import task, workflow, Resources
from flytekitplugins.kfpytorch import (
    CleanPodPolicy,
    Master,
    PyTorch,
    RunPolicy,
    Worker,
)

@task(
    task_config=PyTorch(
        master=Master(num_retries=3),
        worker=Worker(
            replicas=4,
            resources=Resources(
                cpu="4",
                mem="16Gi",
                gpu="1",
            ),
        ),
        run_policy=RunPolicy(
            clean_pod_policy=CleanPodPolicy.RUNNING,
        ),
    ),
    requests=Resources(cpu="2", mem="4Gi"),
    limits=Resources(cpu="4", mem="8Gi"),
)
def distributed_training(num_epochs: int = 10) -> float:
    import torch
    import torch.distributed as dist

    # Initialize the process group; the training-operator injects the
    # rendezvous environment variables (MASTER_ADDR, RANK, WORLD_SIZE)
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Training code here
    local_loss = torch.tensor(float(rank))
    dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
    avg_loss = local_loss.item() / world_size

    dist.destroy_process_group()
    return avg_loss

@workflow
def distributed_wf(num_epochs: int = 10) -> float:
    return distributed_training(num_epochs=num_epochs)
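The averaging at the end of distributed_training can be verified with plain arithmetic: each rank contributes its own rank number as a stand-in loss, all_reduce with SUM leaves every rank holding 0 + 1 + ... + (world_size - 1), and dividing by world_size gives the mean. A torch-free sketch of that reduction:

```python
# Torch-free simulation of the all_reduce(SUM) + divide pattern used above.
def simulated_avg_loss(world_size: int) -> float:
    local_losses = [float(rank) for rank in range(world_size)]  # one value per rank
    reduced = sum(local_losses)  # what all_reduce(SUM) leaves on every rank
    return reduced / world_size  # every rank computes the same mean

print(simulated_avg_loss(4))  # 1.5, i.e. (0 + 1 + 2 + 3) / 4
```

With replicas=4 as configured above, the workflow should therefore return 1.5 regardless of which rank's result Flyte captures.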
Elastic training (torchrun)
from flytekit import task
from flytekitplugins.kfpytorch import Elastic

@task(
    task_config=Elastic(
        nnodes=2,
        nproc_per_node=4,
        start_method="spawn",
    )
)
def elastic_training() -> float:
    import torch

    return torch.cuda.device_count() * 1.0
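With Elastic, torchrun launches nproc_per_node processes on each of nnodes nodes, so the world size is their product and each process's global rank is derived from its node rank and local rank. A quick sketch of that arithmetic (the helper is illustrative, mirroring torchrun's rank layout):

```python
# torchrun-style rank layout: node 0 holds ranks 0..nproc-1,
# node 1 holds the next block of ranks, and so on.
def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    return node_rank * nproc_per_node + local_rank

nnodes, nproc_per_node = 2, 4
world_size = nnodes * nproc_per_node
print(world_size)                          # 8 total processes
print(global_rank(1, 2, nproc_per_node))   # 6 (third process on the second node)
```

So the Elastic config above starts 8 training processes in total, 4 per node.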
Gang scheduling (optional)
For distributed training jobs, all worker pods must be scheduled simultaneously to avoid deadlocks. Enable gang scheduling using one of:
Kubernetes scheduler plugins (co-scheduling)
apiVersion: v1
kind: PodTemplate
metadata:
  name: gang-scheduling-template
  namespace: flyte
template:
  spec:
    schedulerName: "scheduler-plugins-scheduler"
Apache YuniKorn
apiVersion: v1
kind: PodTemplate
metadata:
  name: yunikorn-template
  namespace: flyte
template:
  metadata:
    annotations:
      yunikorn.apache.org/task-group-name: ""
      yunikorn.apache.org/task-groups: ""
      yunikorn.apache.org/schedulingPolicyParameters: ""
See the configuration overview for more details on applying PodTemplates.
Verify
# Check training-operator is healthy
kubectl get pods -n kubeflow
# After running a PyTorch task, check PyTorchJob resource
kubectl get pytorchjob -n flytesnacks-development
# Describe the job for detailed status
kubectl describe pytorchjob <job-name> -n flytesnacks-development
# Check worker pod logs
kubectl logs -n flytesnacks-development <worker-pod-name>
Troubleshooting
Workers stuck in Pending state
Common causes:
Insufficient GPU nodes — check kubectl get nodes -l cloud.google.com/gke-accelerator or equivalent
Missing GPU tolerations — add GPU tolerations to your default PodTemplate
Missing NVIDIA device plugin — install nvidia-device-plugin as a DaemonSet
NCCL communication errors
NCCL requires that all worker pods can communicate directly. Ensure:
No NetworkPolicy blocks pod-to-pod communication within the namespace
Pods can resolve each other’s hostnames via the headless Service created by the training-operator
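The training-operator injects the rendezvous settings into every pod as environment variables (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK), which is what init_process_group reads when called without explicit arguments. Logging them at the top of the task is a quick way to debug rendezvous failures; a small sketch (the "<unset>" fallback is a placeholder for local runs, not a real default):

```python
import os

# Dump the rendezvous variables the training-operator sets on each pod.
# If any show "<unset>" inside a worker pod, the operator did not inject
# them and init_process_group() will fail to rendezvous.
def rendezvous_env() -> dict:
    return {
        var: os.environ.get(var, "<unset>")
        for var in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")
    }

print(rendezvous_env())
```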
Set active_deadline_seconds in the run_policy to prevent stuck jobs from consuming resources indefinitely:

task_config=PyTorch(
    worker=Worker(replicas=4),
    run_policy=RunPolicy(
        active_deadline_seconds=3600,  # 1 hour timeout
    ),
)