Why Custom Strategies?
While FedAvg (Federated Averaging) works well for many scenarios, you might need custom strategies for:

- Weighted aggregation based on data quality or quantity
- Differential privacy with noise injection
- Robust aggregation that handles outliers or malicious clients
- Personalized models with client-specific parameters
- Custom checkpointing for fault tolerance
Built-in Strategy: FedAvgWithModelSaving
Syft-Flwr provides a custom strategy that extends Flower's FedAvg with automatic model checkpointing.

Source Code

The full implementation lives in `syft_flwr/strategy/fedavg.py`.
Key Features
- Automatic Checkpointing: Saves model after each round
- SafeTensors Format: Uses safe, fast serialization format
- Error Handling: Checks directory exists before saving
- Extensible: Easy to customize for your needs
Creating a Custom Strategy
Let's walk through implementing custom strategies for different use cases.

```python
from typing import Dict, List, Optional, Tuple

from flwr.common import (
    EvaluateRes,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from loguru import logger


class CustomStrategy(FedAvg):
    """Template for a custom FL strategy."""

    def __init__(self, *args, **kwargs):
        # Initialize your custom parameters here
        super().__init__(*args, **kwargs)
        logger.info("Initialized CustomStrategy")

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model updates from clients."""
        # Custom aggregation logic here
        # Default: call the parent's FedAvg implementation
        return super().aggregate_fit(server_round, results, failures)

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation metrics from clients."""
        # Custom metric aggregation here
        return super().aggregate_evaluate(server_round, results, failures)
```
```python
from flwr.common import Scalar
import numpy as np


class WeightedFedAvg(FedAvg):
    """FedAvg with custom quality-weighted aggregation."""

    def __init__(self, quality_weights: Dict[str, float] = None, *args, **kwargs):
        """
        Args:
            quality_weights: Dict mapping client IDs to quality scores (0-1)
        """
        super().__init__(*args, **kwargs)
        self.quality_weights = quality_weights or {}

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate with quality-weighted averaging."""
        if not results:
            return None, {}

        # Each client's weight: dataset size * quality score
        total_weight = 0.0
        weighted_params = None
        for client, fit_res in results:
            # Get quality weight (default to 1.0 if not specified)
            quality = self.quality_weights.get(client.cid, 1.0)
            weight = fit_res.num_examples * quality
            total_weight += weight

            params = parameters_to_ndarrays(fit_res.parameters)

            # Accumulate the weighted sum layer by layer
            if weighted_params is None:
                weighted_params = [w * weight for w in params]
            else:
                weighted_params = [
                    wp + (w * weight)
                    for wp, w in zip(weighted_params, params)
                ]

        # Divide by the total weight to finish the average
        aggregated = [w / total_weight for w in weighted_params]

        logger.info(
            f"Round {server_round}: Aggregated {len(results)} clients "
            f"with total weight {total_weight:.2f}"
        )
        return ndarrays_to_parameters(aggregated), {}
```
```python
class DPFedAvg(FedAvg):
    """FedAvg with differential-privacy-style clipping and noise."""

    def __init__(
        self,
        noise_multiplier: float = 0.1,
        clip_norm: float = 1.0,
        *args,
        **kwargs,
    ):
        """
        Args:
            noise_multiplier: Scale of Gaussian noise to add
            clip_norm: Norm-clipping threshold
        """
        super().__init__(*args, **kwargs)
        self.noise_multiplier = noise_multiplier
        self.clip_norm = clip_norm

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate, then clip and add Gaussian noise."""
        # First, perform standard FedAvg aggregation
        parameters, metrics = super().aggregate_fit(server_round, results, failures)
        if parameters is None:
            return None, metrics

        # Clip and add Gaussian noise for privacy
        # (simplified: textbook DP-FedAvg clips each client's update
        # before averaging, not the aggregated result)
        params = parameters_to_ndarrays(parameters)
        noisy_params = []
        for param in params:
            # Clip the tensor to the norm threshold
            norm = np.linalg.norm(param)
            if norm > self.clip_norm:
                param = param * (self.clip_norm / norm)
            # Add noise calibrated to the clipping threshold
            noise = np.random.normal(
                0, self.noise_multiplier * self.clip_norm, param.shape
            )
            noisy_params.append(param + noise)

        logger.info(
            f"Round {server_round}: Added DP noise "
            f"(σ={self.noise_multiplier}, C={self.clip_norm})"
        )
        return ndarrays_to_parameters(noisy_params), metrics
```
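The clip-then-noise step can be sanity-checked in isolation with plain NumPy; this standalone snippet mirrors the loop above:

```python
import numpy as np

rng = np.random.default_rng(0)
clip_norm = 1.0
noise_multiplier = 0.1

param = rng.normal(size=100) * 5.0  # deliberately oversized update
norm = np.linalg.norm(param)
if norm > clip_norm:
    param = param * (clip_norm / norm)  # scale down to the threshold

# After clipping, the norm cannot exceed clip_norm
assert np.linalg.norm(param) <= clip_norm + 1e-9

# Noise std is calibrated to the clipping threshold
noisy = param + rng.normal(0, noise_multiplier * clip_norm, param.shape)
```

Clipping bounds each tensor's sensitivity, which is what makes the noise scale meaningful.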
```python
class RobustFedAvg(FedAvg):
    """FedAvg with robust aggregation using a coordinate-wise trimmed mean."""

    def __init__(self, trimmed_mean_ratio: float = 0.1, *args, **kwargs):
        """
        Args:
            trimmed_mean_ratio: Fraction of extreme values to trim (0-0.5)
        """
        super().__init__(*args, **kwargs)
        self.trim_ratio = trimmed_mean_ratio

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate using a trimmed mean (drop outliers, then average)."""
        if not results:
            return None, {}

        # Extract every client's parameter arrays
        all_params = [
            parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results
        ]

        # How many clients to trim from each end
        num_trim = int(len(all_params) * self.trim_ratio)

        # For each layer, compute the coordinate-wise trimmed mean
        aggregated = []
        for layer_idx in range(len(all_params[0])):
            # Stack all clients' parameters for this layer
            layer_params = np.array([params[layer_idx] for params in all_params])

            if num_trim > 0:
                # Sort along the client axis and drop the extremes
                sorted_params = np.sort(layer_params, axis=0)
                trimmed = sorted_params[num_trim:-num_trim]
                layer_avg = np.mean(trimmed, axis=0)
            else:
                # Nothing to trim: plain average
                layer_avg = np.mean(layer_params, axis=0)
            aggregated.append(layer_avg)

        logger.info(
            f"Round {server_round}: Robust aggregation with "
            f"{num_trim} clients trimmed from each end"
        )
        return ndarrays_to_parameters(aggregated), {}
```
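To see why the trimmed mean resists outliers, compare it against a plain mean on a toy layer where one of five clients sends a poisoned update (standalone NumPy sketch):

```python
import numpy as np

# Five clients report a one-parameter layer; the last one is malicious
layer_params = np.array([[1.0], [1.1], [0.9], [1.0], [100.0]])

trim_ratio = 0.2
num_trim = int(len(layer_params) * trim_ratio)  # trim 1 client from each end

sorted_params = np.sort(layer_params, axis=0)
trimmed_mean = np.mean(sorted_params[num_trim:-num_trim], axis=0)
plain_mean = np.mean(layer_params, axis=0)

print(f"plain mean:   {plain_mean[0]:.2f}")    # 20.80, dragged by the outlier
print(f"trimmed mean: {trimmed_mean[0]:.2f}")  # 1.03, outlier discarded
```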
```python
import json
from datetime import datetime

from syft_flwr.strategy.fedavg import FedAvgWithModelSaving  # path per Source Code above


class AdvancedCheckpointing(FedAvgWithModelSaving):
    """Strategy with enhanced checkpointing and metrics tracking."""

    def __init__(self, save_path: str, *args, **kwargs):
        super().__init__(save_path, *args, **kwargs)
        self.metrics_history = []

    def _save_global_model(self, server_round: int, parameters):
        """Save the model weights plus additional metadata."""
        # Save the model weights via the parent implementation
        super()._save_global_model(server_round, parameters)

        # Save metadata alongside the checkpoint
        metadata = {
            "round": server_round,
            "timestamp": datetime.now().isoformat(),
            "num_parameters": sum(
                p.size for p in parameters_to_ndarrays(parameters)
            ),
        }
        metadata_file = self.save_path / f"metadata_round_{server_round}.json"
        with open(metadata_file, "w") as f:
            json.dump(metadata, f, indent=2)
        logger.info(f"Metadata saved to: {metadata_file}")

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation results and track the metrics history."""
        loss, metrics = super().aggregate_evaluate(server_round, results, failures)

        # Track metrics over time
        self.metrics_history.append(
            {
                "round": server_round,
                "loss": loss,
                "metrics": metrics,
                "num_clients": len(results),
            }
        )

        # Persist the full history after every round
        history_file = self.save_path / "metrics_history.json"
        with open(history_file, "w") as f:
            json.dump(self.metrics_history, f, indent=2)
        return loss, metrics
```
To use a custom strategy, pass it to the strategy argument of your ServerAppComponents:

```python
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig

from your_project.task import Net, get_weights
from your_project.strategies import WeightedFedAvg  # Your custom strategy


def server_fn(context: Context):
    # Initialize the global model
    net = Net()
    params = ndarrays_to_parameters(get_weights(net))

    # Define quality weights for clients
    quality_weights = {
        "client_1": 1.0,  # High quality
        "client_2": 0.8,  # Medium quality
        "client_3": 0.6,  # Lower quality
    }

    # Use the custom strategy
    strategy = WeightedFedAvg(
        quality_weights=quality_weights,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=params,
    )

    num_rounds = context.run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)
    return ServerAppComponents(config=config, strategy=strategy)


app = ServerApp(server_fn=server_fn)
```
Strategy Comparison
| Strategy | Use Case | Pros | Cons |
|---|---|---|---|
| FedAvg | General purpose | Simple, well-tested | Treats all clients equally |
| FedAvgWithModelSaving | Production FL | Checkpointing, recovery | Small overhead |
| WeightedFedAvg | Heterogeneous data quality | Accounts for quality | Requires quality metrics |
| DPFedAvg | Privacy-sensitive data | Privacy guarantees | Reduced accuracy |
| RobustFedAvg | Adversarial clients | Outlier resistant | Slower convergence |
Advanced Techniques
Client Selection
Select clients based on custom criteria, such as dataset size, availability, or past performance.

Adaptive Learning Rates
Adjust aggregation based on training progress, e.g. by decaying the server-side learning rate across rounds.

Personalization
Maintain both a shared global model and client-specific local parameters.

Testing Custom Strategies
Test your strategy in simulation (e.g. in a `test_strategy.py` file) before deploying it to real clients.
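Flower's simulation APIs vary between releases, so rather than pin one here, a hypothetical `test_strategy.py` can start by unit-testing the aggregation math directly; everything in this sketch (names included) is illustrative:

```python
# test_strategy.py (sketch): check the quality-weighted averaging math
import numpy as np


def weighted_average(updates, weights):
    """Reference implementation of the weighted mean WeightedFedAvg computes."""
    total = sum(weights)
    return sum(u * w for u, w in zip(updates, weights)) / total


def test_quality_weighting():
    updates = [np.array([1.0]), np.array([3.0])]
    # Client 2 has twice the data but half the quality -> equal effective weights
    weights = [10 * 1.0, 20 * 0.5]
    assert np.allclose(weighted_average(updates, weights), [2.0])


test_quality_weighting()
```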
Best Practices
Start Simple
Begin with FedAvg or FedAvgWithModelSaving before implementing complex strategies.
Test Thoroughly
Use simulation mode to validate your strategy before deploying to real clients.
Log Everything
Add detailed logging to understand aggregation behavior during training.
Handle Failures
Always check for empty results and handle client failures gracefully.
What’s Next?
- Explore Flower’s built-in strategies
- Learn about federated analytics patterns
- Try advanced FL techniques like FedProx and FedOpt
Common Pitfalls
Forgetting to call super().__init__()
Always call the parent class initializer to set up base functionality:
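A minimal illustration of the failure mode, using a stand-in base class so the snippet runs anywhere (with FedAvg the symptom is the same, usually an AttributeError later in the round):

```python
class Base:
    """Stand-in for FedAvg; its __init__ sets up state the strategy needs."""

    def __init__(self):
        self.min_fit_clients = 2


class BrokenStrategy(Base):
    def __init__(self, my_param):
        self.my_param = my_param  # BUG: Base.__init__ never runs


class FixedStrategy(Base):
    def __init__(self, my_param):
        super().__init__()  # initialize the parent first
        self.my_param = my_param


assert not hasattr(BrokenStrategy(1), "min_fit_clients")
assert FixedStrategy(1).min_fit_clients == 2
```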
Not handling empty results
Always check if results is empty before aggregating:
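The guard is one line at the top of aggregate_fit, as the strategies above show; here it is as a standalone function with a stand-in aggregation so the snippet is runnable on its own:

```python
from typing import Dict, List, Optional, Tuple


def aggregate_or_skip(results: List[tuple]) -> Tuple[Optional[float], Dict]:
    """Return (None, {}) when every client failed this round."""
    if not results:
        return None, {}
    # Stand-in aggregation: example-weighted mean of (value, num_examples) pairs
    total_examples = sum(n for _, n in results)
    value = sum(v * n for v, n in results) / total_examples
    return value, {"num_clients": len(results)}


assert aggregate_or_skip([]) == (None, {})
assert aggregate_or_skip([(1.0, 10), (3.0, 10)]) == (2.0, {"num_clients": 2})
```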
Mismatched parameter shapes
Ensure all clients return parameters with the same shape. Log shapes when debugging:
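A small helper (hypothetical) that logs each client's layer shapes and flags mismatches:

```python
import numpy as np


def check_shapes(client_params: dict) -> bool:
    """Log every client's layer shapes; return True only if they all match."""
    reference = None
    consistent = True
    for cid, params in client_params.items():
        shapes = [p.shape for p in params]
        print(f"{cid}: {shapes}")  # log shapes while debugging
        if reference is None:
            reference = shapes
        elif shapes != reference:
            consistent = False
    return consistent


# The mismatched second layer of client_2 is immediately visible in the logs
check_shapes({
    "client_1": [np.zeros((10, 5)), np.zeros(5)],
    "client_2": [np.zeros((10, 5)), np.zeros(4)],
})
```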