The Flyte KFPyTorch plugin dispatches @task(task_config=PyTorch(...)) tasks to the Kubeflow training-operator, which manages PyTorchJob Kubernetes resources for distributed training.

Prerequisites

  • A running Kubernetes cluster with Flyte installed
  • helm and kubectl configured
  • GPU nodes (optional but typical for training workloads)

Step 1: Install the Kubeflow training-operator

kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone?ref=v1.7.0"
Verify the training-operator is running:
kubectl get pods -n kubeflow
# NAME                                 READY   STATUS    RESTARTS   AGE
# training-operator-xxxxxxxxx-xxxxx    1/1     Running   0          60s
The training-operator also supports TensorFlow (TFJob), MPI (MPIJob), XGBoost (XGBoostJob), and Paddle (PaddleJob). Installing it once enables all these job types.

Step 2: Enable the PyTorch plugin in Flyte

Create a values-pytorch.yaml override file:
configuration:
  inline:
    tasks:
      task-plugins:
        enabled-plugins:
          - container
          - sidecar
          - k8s-array
          - pytorch
        default-for-task-types:
          - container: container
          - container_array: k8s-array
          - pytorch: pytorch
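This mapping tells flytepropeller which backend plugin executes each task type. Conceptually the dispatch works like the following stdlib sketch (the function and fallback behavior are illustrative, not Flyte's actual Go implementation):

```python
# Hypothetical sketch of task-type-to-plugin routing; it mirrors the
# values-pytorch.yaml mapping above, not flytepropeller's real internals.
default_for_task_types = {
    "container": "container",
    "container_array": "k8s-array",
    "pytorch": "pytorch",
}

def resolve_plugin(task_type: str, enabled: set) -> str:
    # Fall back to the plain container plugin for unknown task types.
    plugin = default_for_task_types.get(task_type, "container")
    if plugin not in enabled:
        raise ValueError(f"plugin {plugin!r} is not in enabled-plugins")
    return plugin

enabled_plugins = {"container", "sidecar", "k8s-array", "pytorch"}
print(resolve_plugin("pytorch", enabled_plugins))  # -> pytorch
```

If `pytorch` is missing from `enabled-plugins`, tasks with `task_config=PyTorch(...)` cannot be dispatched, which is why both lists must be updated together.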
Apply the override:
helm upgrade flyte-backend flyteorg/flyte-binary \
  --namespace flyte \
  --values values.yaml \
  --values values-pytorch.yaml

Step 3: Write a distributed PyTorch task

Install the flytekit PyTorch plugin:
pip install flytekitplugins-kfpytorch

Single-worker PyTorch task

from flytekit import task, workflow
from flytekitplugins.kfpytorch import PyTorch, Worker


@task(
    task_config=PyTorch(
        worker=Worker(replicas=1),
    )
)
def single_node_training() -> float:
    import torch
    x = torch.tensor([1.0, 2.0, 3.0])
    return x.mean().item()


@workflow
def single_node_wf() -> float:
    return single_node_training()
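Because Flyte tasks are plain Python functions, you can sanity-check the logic locally before registering anything. Without torch installed, the computation above reduces to an arithmetic mean; a stdlib stand-in:

```python
from statistics import fmean

# Equivalent of torch.tensor([1.0, 2.0, 3.0]).mean().item()
result = fmean([1.0, 2.0, 3.0])
print(result)  # -> 2.0
```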

Multi-worker distributed training

from flytekit import task, workflow, Resources
from flytekitplugins.kfpytorch import CleanPodPolicy, Master, PyTorch, RunPolicy, Worker


@task(
    task_config=PyTorch(
        master=Master(),
        worker=Worker(
            replicas=4,
            requests=Resources(
                cpu="4",
                mem="16Gi",
                gpu="1",
            ),
            limits=Resources(
                cpu="4",
                mem="16Gi",
                gpu="1",
            ),
        ),
        run_policy=RunPolicy(
            clean_pod_policy=CleanPodPolicy.RUNNING,
        ),
    ),
    requests=Resources(cpu="2", mem="4Gi"),
    limits=Resources(cpu="4", mem="8Gi"),
)
def distributed_training(num_epochs: int = 10) -> float:
    import torch
    import torch.distributed as dist

    # Initialize the process group
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Training loop over num_epochs epochs would go here
    local_loss = torch.tensor(float(rank))
    dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
    avg_loss = local_loss.item() / world_size

    dist.destroy_process_group()
    return avg_loss


@workflow
def distributed_wf(num_epochs: int = 10) -> float:
    return distributed_training(num_epochs=num_epochs)
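The all_reduce arithmetic in distributed_training can be checked without a cluster: every process contributes its rank as a stand-in loss, the SUM reduction leaves each process holding the global total, and dividing by world_size yields the same average everywhere. A plain-Python simulation (assuming a world size of 5, i.e. the master plus 4 workers; the exact world size depends on how the PyTorchJob is configured):

```python
world_size = 5  # master pod + 4 Worker replicas, one process each (assumed)
local_losses = [float(rank) for rank in range(world_size)]  # rank as stand-in loss

# all_reduce(SUM) leaves every process holding the global sum
global_sum = sum(local_losses)      # 0 + 1 + 2 + 3 + 4 = 10.0
avg_loss = global_sum / world_size  # 2.0 on every rank
print(avg_loss)  # -> 2.0
```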

Elastic training (torchrun)

from flytekitplugins.kfpytorch import Elastic


@task(
    task_config=Elastic(
        nnodes=2,
        nproc_per_node=4,
        start_method="spawn",
    )
)
def elastic_training() -> float:
    import torch
    return torch.cuda.device_count() * 1.0
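Under Elastic, flytekit runs the task with torchrun semantics: nnodes * nproc_per_node processes in total, each assigned a global rank derived from its node index and local rank. The resulting layout is simple arithmetic (no torch needed to verify it):

```python
nnodes = 2
nproc_per_node = 4

world_size = nnodes * nproc_per_node  # 8 processes in total
layout = [
    {"node": node, "local_rank": lr, "rank": node * nproc_per_node + lr}
    for node in range(nnodes)
    for lr in range(nproc_per_node)
]
print(world_size)          # -> 8
print(layout[-1]["rank"])  # -> 7 (last process on the second node)
```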

Gang scheduling (optional)

For distributed training jobs, all worker pods must be scheduled simultaneously to avoid deadlocks. Enable gang scheduling using one of:

Kubernetes scheduler plugins (co-scheduling)

apiVersion: v1
kind: PodTemplate
metadata:
  name: gang-scheduling-template
  namespace: flyte
template:
  spec:
    schedulerName: "scheduler-plugins-scheduler"

Apache YuniKorn

apiVersion: v1
kind: PodTemplate
metadata:
  name: yunikorn-template
  namespace: flyte
template:
  metadata:
    annotations:
      yunikorn.apache.org/task-group-name: ""
      yunikorn.apache.org/task-groups: ""
      yunikorn.apache.org/schedulingPolicyParameters: ""
See the configuration overview for more details on applying PodTemplates.

Verify

# Check training-operator is healthy
kubectl get pods -n kubeflow

# After running a PyTorch task, check PyTorchJob resource
kubectl get pytorchjob -n flytesnacks-development

# Describe the job for detailed status
kubectl describe pytorchjob <job-name> -n flytesnacks-development

# Check worker pod logs
kubectl logs -n flytesnacks-development <worker-pod-name>

Troubleshooting

Pods stuck in Pending. Common causes:
  • Insufficient GPU nodes — check kubectl get nodes -l cloud.google.com/gke-accelerator or equivalent
  • Missing GPU tolerations — add GPU tolerations to your default PodTemplate
  • Missing NVIDIA device plugin — install nvidia-device-plugin as a DaemonSet
NCCL errors or hangs. NCCL requires that all worker pods can communicate directly. Ensure:
  • No NetworkPolicy blocks pod-to-pod communication within the namespace
  • Pods can resolve each other’s hostnames via the headless Service created by the training-operator
Jobs that never finish. Set active_deadline_seconds in the run_policy to prevent stuck jobs from consuming resources indefinitely:
task_config=PyTorch(
    worker=Worker(replicas=4),
    run_policy=RunPolicy(
        active_deadline_seconds=3600,  # 1 hour timeout
    ),
)
