Federated Learning Architecture
This document describes how Oneliac implements federated learning to train diagnosis models across multiple healthcare organizations without centralizing sensitive patient data.
What is Federated Learning?
Federated learning is a distributed machine learning technique where:
- Multiple parties (healthcare providers) collaboratively train a model
- Raw patient data never leaves the provider's infrastructure
- Only encrypted gradients are shared with a central coordinator
- Privacy is preserved through encryption and differential privacy
Federated Learning in Healthcare
Traditional Approach (Problematic)
Hospital A → [Patient Data] ──┐
Hospital B → [Patient Data] ──→ Central Server → Train Model
Hospital C → [Patient Data] ──┘
Problems:
- Patient data exposure during transmission
- HIPAA violations possible
- Data centralization risk
- Regulatory compliance difficult
Oneliac Federated Approach (Privacy-Preserving)
Hospital A → [Gradient Computation] → [Encrypt] ──┐
Hospital B → [Gradient Computation] → [Encrypt] ──→ Secure Aggregation → Update Model
Hospital C → [Gradient Computation] → [Encrypt] ──┘
Benefits:
- Raw patient data never shared
- Only gradients transmitted
- Encryption ensures privacy
- Differential privacy adds noise
- Compliant with HIPAA/GDPR
Federated Learning Pipeline
Step 1: Data Reception
Healthcare Provider
↓
Submits encrypted patient batch
{
  "patient_id": "patient_123",
  "encrypted_data": "<fernet-encrypted>",
  "ipfs_cid": "QmXx...",
  "data_hash": "sha256..."
}
↓
Oneliac FederatedLearningCoordinator receives request
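A minimal provider-side sketch of assembling this payload is shown below. It assumes a Fernet key shared out-of-band with the coordinator and an IPFS CID obtained from a prior upload; the build_submission helper is purely illustrative and not part of Oneliac's API.
import hashlib
import json
from cryptography.fernet import Fernet

def build_submission(patient_id: str, record: dict, key: bytes, ipfs_cid: str) -> dict:
    """Encrypt a patient record and package it for the coordinator (illustrative sketch)."""
    raw = json.dumps(record).encode()
    return {
        "patient_id": patient_id,
        "encrypted_data": Fernet(key).encrypt(raw).decode(),   # only ciphertext leaves the provider
        "ipfs_cid": ipfs_cid,                                  # reference from an earlier IPFS upload
        "data_hash": hashlib.sha256(raw).hexdigest(),          # integrity check on the plaintext
    }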
Step 2: Local Gradient Computation
For each patient in batch:
↓
Load patient medical history
Extract features (symptoms, vital signs, lab results)
Forward pass through diagnosis model
Compute loss against ground truth diagnosis
Backward pass to compute gradients
↓
gradient = ∂loss/∂weights
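As a concrete illustration, the local computation for one batch might look like the sketch below; the feature dimension, toy linear model, and cross-entropy loss are stand-ins for the actual diagnosis model.
import torch
import torch.nn as nn

model = nn.Linear(32, 10)                   # stand-in: 32 extracted features, 10 diagnosis classes
features = torch.randn(8, 32)               # batch of 8 patients' feature vectors
labels = torch.randint(0, 10, (8,))         # ground-truth diagnosis codes

logits = model(features)                    # forward pass
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()                             # backward pass populates .grad

gradients = [p.grad for p in model.parameters()]   # ∂loss/∂weights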
Step 3: Differential Privacy
Add Gaussian noise to gradients:
↓
noise ~ Normal(0, σ²)
σ = Δ / √(2ρ)   [Gaussian mechanism under zCDP; Δ = L2 sensitivity of the clipped gradient]
↓
noisy_gradient = gradient + noise
↓
Limits what any observer can learn about an individual patient
by inspecting the shared gradients
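A sketch of the noise calibration is shown below. It assumes each per-patient gradient is first clipped to a fixed L2 norm so the sensitivity Δ is bounded; the clipping step and the privatize helper are assumptions added for illustration.
import math
import torch

def privatize(gradient: torch.Tensor, rho: float = 0.5, clip_norm: float = 1.0) -> torch.Tensor:
    """Clip to L2 norm clip_norm, then add Gaussian noise calibrated for rho-zCDP (sketch)."""
    # Bound any single patient's contribution: sensitivity Δ = clip_norm
    scale = torch.clamp(clip_norm / (gradient.norm() + 1e-12), max=1.0)
    clipped = gradient * scale
    sigma = clip_norm / math.sqrt(2 * rho)                     # σ = Δ / √(2ρ)
    noise = torch.normal(mean=0.0, std=sigma, size=clipped.shape)
    return clipped + noise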
Step 4: Encryption
noisy_gradient
↓
Convert to bytes: gradient.numpy().tobytes()
↓
Encrypt with Fernet (symmetric key)
encrypted_gradient = cipher.encrypt(gradient_bytes)
↓
Only coordinator holds decryption key
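For reference, the Fernet round trip used in this step looks like the following (standard cryptography-package API; the 100-element gradient is illustrative):
import numpy as np
import torch
from cryptography.fernet import Fernet

key = Fernet.generate_key()                 # held only by the coordinator
cipher = Fernet(key)

noisy_gradient = torch.randn(100)
token = cipher.encrypt(noisy_gradient.numpy().tobytes())   # ciphertext, safe to transmit

# Coordinator side: decrypt and restore the tensor
restored = torch.tensor(np.frombuffer(cipher.decrypt(token), dtype=np.float32))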
Step 5: Secure Aggregation
Collect encrypted gradients from multiple providers:
↓
encrypted_grad_1 (Hospital A)
encrypted_grad_2 (Hospital B)
encrypted_grad_3 (Hospital C)
↓
For each encrypted gradient:
- Decrypt with coordinator key
- Convert to tensor
↓
avg_gradient = mean([grad_1, grad_2, grad_3])
Step 6: Model Update
Global diagnosis model parameters
↓
For each parameter: param -= learning_rate * avg_gradient
↓
Updated model weights
↓
Compute new model hash
↓
Increment round number
Step 7: Return Results
Return to providers:
{
  "round": 5,
  "participants": 3,
  "model_hash": "abc123...",
  "privacy_budget_used": 0.15
}
Implementation Details
FederatedLearningCoordinator Class
from __future__ import annotations  # type hints (e.g. PatientData) are not evaluated at runtime

import hashlib
import json
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
from cryptography.fernet import Fernet

# PatientData is the provider-submitted record type, defined elsewhere in the codebase.

class FederatedLearningCoordinator:
    def __init__(self, model: nn.Module, num_agents: int = 3):
        """
        Initialize federated learning coordinator.

        Args:
            model: Global diagnosis model (PyTorch nn.Module)
            num_agents: Maximum participants per round
        """
        self.global_model = model
        self.num_agents = num_agents
        self.round_number = 0

        # Encryption setup
        self.encryption_key = Fernet.generate_key()
        self.cipher = Fernet(self.encryption_key)

        # Differential privacy parameters
        # zCDP: zero-concentrated differential privacy
        # Gaussian mechanism: σ = Δ / √(2ρ), with L2 sensitivity Δ = 1
        # and per-round budget ρ = 0.5, giving σ = 1.0
        self.dp_sigma = 1.0 / np.sqrt(2 * 0.5)  # σ = Δ / √(2ρ) = 1.0
    async def train_round(
        self,
        agent_data: List[PatientData]
    ) -> Dict:
        """
        Execute one federated learning training round.

        Args:
            agent_data: List of encrypted patient data from providers

        Returns:
            {
                "round": round_number,
                "participants": num_participants,
                "model_hash": model_hash
            }
        """
        print(f"[FL] Starting round {self.round_number}")

        # Limit participants
        participants = min(len(agent_data), self.num_agents)

        # Step 1: Compute encrypted gradients
        encrypted_gradients = []
        for data in agent_data[:participants]:
            gradient = await self._compute_encrypted_gradient(data)
            encrypted_gradients.append(gradient)

        # Step 2: Secure aggregation
        aggregated = self._secure_aggregate(encrypted_gradients)

        # Step 3: Update global model
        self._update_global_model(aggregated)

        # Step 4: Record results
        self.round_number += 1
        return {
            "round": self.round_number,
            "participants": participants,
            "model_hash": self._compute_model_hash()
        }
    async def _compute_encrypted_gradient(
        self,
        data: PatientData
    ) -> bytes:
        """
        Compute encrypted gradient from patient data.

        Process:
        1. Forward pass through model
        2. Compute loss
        3. Backward pass (compute gradients)
        4. Add differential privacy noise
        5. Encrypt result
        """
        # Simulate gradient computation
        gradient = torch.randn(100)

        # Add Gaussian noise for differential privacy
        noise = torch.normal(
            mean=0.0,
            std=self.dp_sigma,
            size=gradient.shape
        )
        gradient += noise

        # Convert to bytes
        gradient_bytes = gradient.numpy().tobytes()

        # Encrypt
        encrypted = self.cipher.encrypt(gradient_bytes)
        return encrypted
    def _secure_aggregate(
        self,
        encrypted_gradients: List[bytes]
    ) -> bytes:
        """
        Securely aggregate encrypted gradients.

        Process:
        1. Decrypt gradients (coordinator only)
        2. Average them
        3. Re-encrypt aggregated result
        """
        all_grads = []

        # Decrypt and collect
        for enc_grad in encrypted_gradients:
            decrypted = self.cipher.decrypt(enc_grad)
            grad_array = np.frombuffer(decrypted, dtype=np.float32)
            grad_tensor = torch.tensor(grad_array)
            all_grads.append(grad_tensor)

        # Average
        avg_gradient = torch.mean(torch.stack(all_grads), dim=0)

        # Re-encrypt
        return self.cipher.encrypt(avg_gradient.numpy().tobytes())
    def _update_global_model(self, aggregated_gradient: bytes):
        """Update global model with aggregated gradient."""
        decrypted = self.cipher.decrypt(aggregated_gradient)
        grad_array = np.frombuffer(decrypted, dtype=np.float32)
        gradient = torch.tensor(grad_array)

        # Gradient descent update (simplified: the gradient is applied to the
        # first parameter tensor whose element count matches)
        with torch.no_grad():
            for param in self.global_model.parameters():
                if param.numel() == gradient.numel():
                    # param -= learning_rate * gradient
                    param -= 0.01 * gradient.view_as(param)
                    break
    def _compute_model_hash(self) -> str:
        """Compute SHA256 hash of current model weights."""
        state_dict = self.global_model.state_dict()
        model_bytes = json.dumps(
            {k: v.cpu().numpy().tolist() for k, v in state_dict.items()}
        ).encode()
        return hashlib.sha256(model_bytes).hexdigest()
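A minimal usage sketch for the coordinator above follows. The toy nn.Linear model and the empty stand-in records are illustrative assumptions; a real deployment would pass actual PatientData objects.
import asyncio
import torch.nn as nn

async def main():
    model = nn.Linear(10, 10)                        # toy diagnosis model (weight has 100 elements)
    coordinator = FederatedLearningCoordinator(model, num_agents=3)

    # One submission per hospital; contents are stand-ins for PatientData records
    agent_data = [{}, {}, {}]

    result = await coordinator.train_round(agent_data)
    print(result)   # {"round": 1, "participants": 3, "model_hash": "..."}

asyncio.run(main())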
Privacy Analysis
Differential Privacy Guarantees
Oneliac uses zCDP (zero-concentrated differential privacy) to bound privacy loss:
Definition (ρ-zCDP): A randomized mechanism M satisfies ρ-zero-concentrated
differential privacy if, for all neighboring datasets D, D' and all α > 1:
D_α(M(D) ‖ M(D')) ≤ ρ · α
where D_α is the Rényi divergence of order α. Equivalently, the privacy loss L
satisfies E[exp((α − 1) · L)] ≤ exp((α − 1) · α · ρ) for every α > 1.
Example Privacy Budget
For a single training round with Gaussian noise (σ = 1.0), assuming per-patient gradients are clipped to L2 norm 1 (sensitivity Δ = 1):
# Privacy parameters (Gaussian mechanism under zCDP)
import numpy as np

sensitivity = 1.0                                   # L2 sensitivity Δ (clipped gradients)
sigma = 1.0                                         # Noise scale σ
rho = sensitivity**2 / (2 * sigma**2)               # Per-round budget ρ = 0.5

# After 10 rounds, zCDP budgets compose additively
total_rho = 10 * rho                                # = 5.0

# Convert to (ε, δ)-differential privacy for δ = 10^-6
delta = 1e-6
total_epsilon = total_rho + 2 * np.sqrt(total_rho * np.log(1 / delta))   # ≈ 21.6
This means: after 10 rounds, the mechanism satisfies roughly (21.6, 10^-6)-differential privacy. Tightening this guarantee requires a smaller per-round ρ (more noise) or fewer rounds.
What Differential Privacy Prevents
- Membership Inference: an adversary cannot reliably determine whether a specific patient's record was in the training set
- Model Inversion: individual training records cannot be reconstructed from shared gradients or model weights
- Attribute Inference: sensitive attributes of individual patients cannot be inferred with confidence
Data Flow Diagram
┌─────────────────────────────┐
│ Hospital A's System │
│ Patient Data: │
│ - Demographics │
│ - Symptoms: fever, cough │
│ - Vitals: BP 120/80 │
│ - Labs: WBC 12k │
└─────────────────────────────┘
↓
Encrypt (Fernet)
↓
┌─────────────────────────────────────────────────┐
│ Encrypted Data transmitted to Coordinator │
│ (Network eavesdropping sees only ciphertext) │
└─────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────┐
│ Coordinator (Oneliac) │
│ - Decrypt (only coordinator has key) │
│ - Extract features from medical history │
│ - Forward pass: model(features) → diagnosis │
│ - Compute loss: cross_entropy(pred, truth) │
│ - Backward pass: loss.backward() → gradients │
│ - Add noise: gradient += Normal(0, σ) │
│ - Encrypt: cipher.encrypt(gradient_bytes) │
└─────────────────────────────────────────────────┘
↓
Encrypted gradients from:
Hospital A, Hospital B, Hospital C
↓
┌─────────────────────────────────────────────────┐
│ Secure Aggregation │
│ 1. Decrypt all gradients │
│ 2. avg = mean([grad_A, grad_B, grad_C]) │
│ 3. Encrypt aggregated gradient │
└─────────────────────────────────────────────────┘
↓
Update Global Model
param -= 0.01 * avg_gradient
↓
Send updated model weights back to hospitals
(Also protected by differential privacy)
Convergence Analysis
Theoretical Guarantees
Federated Learning with differential privacy has convergence guarantees:
Theorem (Approximate Convergence):
If each round processes m samples and the noise scale σ is bounded,
then after T rounds:
E[‖∇f(θ_T)‖²] ≤ O(1/T + σ²/m)
- First term: Standard SGD convergence
- Second term: Privacy cost (decreases with more samples)
Practical Convergence
In practice:
- 10-20 rounds: the model typically reaches reasonable accuracy
- 50+ rounds: very high accuracy, though the cumulative privacy budget grows with each round
- More participants per round: faster convergence and a better global model
Monitoring and Metrics
Key Metrics to Track
# Example monitoring during training
metrics = {
    "round": 5,
    "num_participants": 3,
    "total_samples": 1500,          # Samples from all 3 hospitals
    "avg_gradient_norm": 2.4,
    "model_loss": 0.32,
    "privacy_budget_used": 0.15,    # Out of 1.0
    "time_per_round": 2.3,          # seconds
}
Privacy Budget Tracking
class PrivacyBudgetTracker:
    def __init__(self, epsilon_total: float = 1.0):
        self.epsilon_total = epsilon_total
        self.epsilon_used = 0.0

    def log_round(self, sigma: float, num_samples: int):
        """Log privacy consumption for a round."""
        epsilon_this_round = self._compute_privacy_cost(sigma)
        self.epsilon_used += epsilon_this_round
        print(f"Epsilon used this round: {epsilon_this_round:.4f}")
        print(f"Total epsilon used: {self.epsilon_used:.4f}")
        print(f"Budget remaining: {self.epsilon_total - self.epsilon_used:.4f}")

    def _compute_privacy_cost(self, sigma: float) -> float:
        """Per-round privacy cost.

        Simplified placeholder: returns the zCDP budget ρ = Δ²/(2σ²) of a
        sensitivity-1 Gaussian mechanism as the per-round increment. A
        production tracker would use a formal accountant and convert to
        (ε, δ)-DP under composition.
        """
        return 1.0 / (2 * sigma ** 2)
Best Practices
- Key Management: Store encryption keys separately from encrypted data
- Round Coordination: Establish agreed-upon schedules for training rounds
- Model Validation: Test updated models on held-out data before deployment
- Privacy Auditing: Regularly compute and log privacy budget consumption
- Participant Verification: Ensure all participants are authorized healthcare providers
- Error Handling: Gracefully handle dropped participants during aggregation (see the sketch below)
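For the last point, a hypothetical helper shows one way to tolerate dropped or corrupt submissions during aggregation; the function name and minimum-participant threshold are illustrative and not part of Oneliac's coordinator.
import numpy as np
import torch
from cryptography.fernet import Fernet, InvalidToken

def aggregate_with_dropouts(cipher: Fernet, encrypted_gradients: list, min_participants: int = 2) -> torch.Tensor:
    """Average only the gradients that decrypt successfully (illustrative sketch)."""
    grads = []
    for token in encrypted_gradients:
        try:
            raw = cipher.decrypt(token)
        except InvalidToken:
            continue                                    # skip participants whose payload is invalid
        grads.append(torch.tensor(np.frombuffer(raw, dtype=np.float32)))
    if len(grads) < min_participants:
        raise RuntimeError("Not enough valid participants for this round")
    return torch.mean(torch.stack(grads), dim=0)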
Limitations
- Communication Overhead: Transmitting gradients requires bandwidth
- Model Size: Large models → large gradients → slow transmission
- Synchronization: All participants must be ready for each round
- Privacy-Accuracy Trade-off: More noise → better privacy → worse accuracy