This tutorial walks you through creating your first federated learning (FL) project using Syft-Flwr, from setting up the environment to running your first FL training session.

What You’ll Build

You’ll create a diabetes prediction model trained collaboratively across multiple data owners without centralizing their data. The complete project structure will look like:
fl-diabetes-prediction/
├── fl_diabetes_prediction/
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp
│   ├── server_app.py   # Defines your ServerApp
│   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata and dependencies
└── README.md

Prerequisites

  • Python 3.12 or later
  • Basic understanding of PyTorch and machine learning
  • Access to at least 2 machines or environments (for data owners)
Install Syft-Flwr

Install the Syft-Flwr framework:
pip install syft-flwr
Or install from source:
pip install "git+https://github.com/OpenMined/syft-flwr.git@main"
Create Project Structure

Create your project directory and files:
mkdir -p fl-diabetes-prediction/fl_diabetes_prediction
cd fl-diabetes-prediction
touch fl_diabetes_prediction/__init__.py
touch fl_diabetes_prediction/client_app.py
touch fl_diabetes_prediction/server_app.py
touch fl_diabetes_prediction/task.py
touch pyproject.toml
Define the Model and Training Logic

Create your neural network model in task.py:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict

class Net(nn.Module):
    def __init__(self, input_dim=6):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(32, 24),
            nn.BatchNorm1d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(24, 16),
            nn.BatchNorm1d(16),
            nn.LeakyReLU(0.1)
        )
        self.output_layer = nn.Sequential(
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.output_layer(x)
        return x

def train(model, train_loader, local_epochs=1):
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
    model.train()
    
    for epoch in range(local_epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

def evaluate(model, data_loader):
    model.eval()
    criterion = nn.BCELoss()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def get_weights(model):
    return [val.cpu().numpy() for _, val in model.state_dict().items()]

def set_weights(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)
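The get_weights and set_weights helpers convert between a PyTorch state dict and the list of NumPy arrays that Flower exchanges. As a quick local check (not part of the tutorial files, and using a tiny stand-in model rather than Net), you can verify that the round trip preserves parameters:

```python
import torch
import torch.nn as nn
from collections import OrderedDict

def get_weights(model):
    # Extract all parameters and buffers as NumPy arrays, in state_dict order
    return [val.cpu().numpy() for _, val in model.state_dict().items()]

def set_weights(model, parameters):
    # Zip arrays back onto the state_dict keys and load them into the model
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)

# Two models with different random initializations
model_a = nn.Linear(6, 1)
model_b = nn.Linear(6, 1)

# Copy model_a's weights into model_b via the NumPy round trip
set_weights(model_b, get_weights(model_a))

for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    assert torch.equal(p_a, p_b)
```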
Implement the Data Loading

Add the data loading function in task.py. This is the key difference from standard Flower: instead of loading a public dataset, the data comes from SyftBox:
def load_syftbox_dataset():
    """Load dataset from SyftBox private data directory"""
    import pandas as pd
    from syft_flwr.utils import get_syftbox_dataset_path
    
    # Get the private dataset path set by SyftBox
    data_dir = get_syftbox_dataset_path()
    
    # Load train and test data
    train_df = pd.read_csv(data_dir / "train.csv")
    test_df = pd.read_csv(data_dir / "test.csv")
    
    # Process and return DataLoaders
    return dataset_processing(train_df, test_df)
The get_syftbox_dataset_path() function retrieves the path to private data that only the data owner can access. This ensures data never leaves the owner’s machine.
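The dataset_processing helper referenced above is left for you to implement. A minimal sketch might look like the following; the label column name ("Outcome") and batch size are assumptions for illustration, not part of the tutorial:

```python
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

def dataset_processing(train_df, test_df, label_col="Outcome", batch_size=32):
    """Sketch: scale features and wrap train/test splits in DataLoaders."""
    x_train = train_df.drop(columns=[label_col]).values
    x_test = test_df.drop(columns=[label_col]).values

    # Fit the scaler on training data only to avoid leaking test statistics
    scaler = StandardScaler().fit(x_train)
    x_train, x_test = scaler.transform(x_train), scaler.transform(x_test)

    def to_loader(x, y, shuffle):
        ds = TensorDataset(
            torch.tensor(x, dtype=torch.float32),
            # BCELoss expects targets shaped (N, 1) to match the model output
            torch.tensor(y, dtype=torch.float32).unsqueeze(1),
        )
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    return (
        to_loader(x_train, train_df[label_col].values, shuffle=True),
        to_loader(x_test, test_df[label_col].values, shuffle=False),
    )
```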
Create the Client Application

Define how clients participate in training in client_app.py:
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

from fl_diabetes_prediction.task import (
    Net,
    evaluate,
    get_weights,
    set_weights,
    train,
    load_syftbox_dataset,
)

class FlowerClient(NumPyClient):
    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        set_weights(self.net, parameters)
        train(self.net, self.trainloader)
        return get_weights(self.net), len(self.trainloader.dataset), {}  # report example count, not batch count

    def evaluate(self, parameters, config):
        set_weights(self.net, parameters)
        loss, accuracy = evaluate(self.net, self.testloader)
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}

def client_fn(context: Context):
    # Load the private dataset from SyftBox
    train_loader, test_loader = load_syftbox_dataset()
    net = Net()
    return FlowerClient(net, train_loader, test_loader).to_client()

app = ClientApp(client_fn=client_fn)
Create the Server Application

Define the aggregation strategy in server_app.py:
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from syft_flwr.strategy import FedAvgWithModelSaving
from pathlib import Path
import os

from fl_diabetes_prediction.task import Net, get_weights

def weighted_average(metrics):
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

def server_fn(context: Context):
    # Initialize the model
    net = Net()
    params = ndarrays_to_parameters(get_weights(net))
    
    # Set up model save path
    output_dir = os.getenv("OUTPUT_DIR", Path.home() / ".syftbox/rds/")
    save_path = Path(output_dir) / "weights"
    
    # Configure the aggregation strategy
    strategy = FedAvgWithModelSaving(
        save_path=save_path,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        min_fit_clients=2,
        min_evaluate_clients=2,
        initial_parameters=params,
        evaluate_metrics_aggregation_fn=weighted_average,
    )
    
    num_rounds = context.run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)
    
    return ServerAppComponents(config=config, strategy=strategy)

app = ServerApp(server_fn=server_fn)
FedAvgWithModelSaving is a custom strategy that saves the global model to disk after each training round, making it easy to track progress and recover from failures.
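The weighted_average function weights each client's accuracy by its number of training examples, so larger datasets contribute proportionally more to the reported global accuracy. A quick standalone illustration:

```python
def weighted_average(metrics):
    # metrics is a list of (num_examples, metrics_dict) pairs, one per client
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# Two clients: 100 examples at 0.80 accuracy, 300 examples at 0.90
result = weighted_average([(100, {"accuracy": 0.80}), (300, {"accuracy": 0.90})])
print(result)  # weighted accuracy is approximately 0.875, not the plain mean 0.85
```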
Configure the Project

Create pyproject.toml with project metadata and dependencies:
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "fl-diabetes-prediction"
version = "1.0.0"
requires-python = ">=3.12"
description = "Federated Learning for Diabetes Prediction"
license = "Apache-2.0"
dependencies = [
    "flwr-datasets>=0.5.0",
    "torch>=2.8.0",
    "imblearn",
    "pandas",
    "scikit-learn==1.6.1",
    "loguru",
    "syft_flwr",
]

[tool.flwr.app]
publisher = "YourName"

[tool.flwr.app.components]
serverapp = "fl_diabetes_prediction.server_app:app"
clientapp = "fl_diabetes_prediction.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
min-available-clients = 2
min-fit-clients = 2
min-evaluate-clients = 2
fraction-fit = 1.0
fraction-evaluate = 1.0
Bootstrap the FL Project

Before running, bootstrap your project with Syft-Flwr metadata:
import syft_flwr
from pathlib import Path

project_path = Path("./fl-diabetes-prediction")
aggregator_email = "[email protected]"  # Data scientist email
datasite_emails = ["[email protected]", "[email protected]"]  # Data owner emails

syft_flwr.bootstrap(
    project_path,
    aggregator=aggregator_email,
    datasites=datasite_emails
)
This creates a main.py file and updates pyproject.toml with:
  • Unique app name for this FL run
  • List of participating data owners
  • Aggregator (data scientist) email
Test with Simulation (Optional)

Before running on real data, test locally with mock data:

flwr run ./fl-diabetes-prediction
    
Or test with Syft-Flwr simulation:

syft_flwr.run(
    project_path,
    mock_paths=["path/to/mock1", "path/to/mock2"]
)
    
Submit to Data Owners

Data scientists submit the FL project to data owners for approval:

import syft_rds as sy

# Connect to data owner's datasite
do1_client = sy.init_session(
    host="[email protected]",
    email="[email protected]"
)

# Submit the job
do1_client.job.submit(
    name="fl-diabetes-prediction",
    user_code_path=project_path,
    dataset_name="pima-indians-diabetes-database",
    entrypoint="main.py",
)
    
Run the FL Server

Once data owners approve and run the client code, start the aggregation server:

# Submit to yourself to run the server
ds_client = sy.init_session(
    host="[email protected]",
    email="[email protected]"
)

job = ds_client.job.submit(
    name="fl-diabetes-prediction-server",
    user_code_path=project_path,
    entrypoint="main.py",
)

# Approve and run
ds_client.job.approve(job)
ds_client.run_private(job, blocking=True)
    
Monitor Results

The aggregated model weights are saved after each round in the weights/ directory:

weights/
├── parameters_round_1.safetensors
├── parameters_round_2.safetensors
└── parameters_round_3.safetensors

View training logs:

ds_client.job.show_logs(job)
    

Key Differences from Standard Flower

Syft-Flwr requires only minimal changes to a standard Flower project:
1. Data Loading: Use load_syftbox_dataset() instead of loading from public datasets
2. Bootstrap Step: Run syft_flwr.bootstrap() to configure participants
3. Communication: Messages are exchanged via file sync instead of network connections

Common Issues

Dataset path errors: this means the DATA_DIR environment variable is not set. Make sure:
  • You’re running the code through SyftBox job execution
  • The dataset is properly registered with SyftBox
  • You’re using load_syftbox_dataset() in your client code

Training does not start or clients appear stuck: check that:
  • Data owners have approved the job requests
  • All participants are running their respective code
  • The app_name in pyproject.toml matches across all participants
  • SyftBox is running and syncing files properly

Model weights are not being saved: verify that:
  • The OUTPUT_DIR environment variable points to a writable location
  • You’re using the FedAvgWithModelSaving strategy
  • The save path directory exists and has write permissions
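The environment checks above can be automated with a small diagnostic sketch (a local helper, not part of the tutorial files):

```python
import os
from pathlib import Path

def check_env():
    """Report whether the SyftBox-related environment variables look usable."""
    problems = []

    data_dir = os.getenv("DATA_DIR")
    if not data_dir:
        problems.append("DATA_DIR is not set (run through SyftBox job execution)")
    elif not Path(data_dir).is_dir():
        problems.append(f"DATA_DIR points to a missing directory: {data_dir}")

    output_dir = os.getenv("OUTPUT_DIR")
    if output_dir and not os.access(output_dir, os.W_OK):
        problems.append(f"OUTPUT_DIR is not writable: {output_dir}")

    return problems

for problem in check_env():
    print("WARNING:", problem)
```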
