Loading...
Loading...
Federated learning with Deep Q-Networks for privacy-preserving optimization
npx skill4agent add kinhluan/skills federated-learning-dqn

┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│ Hospital A  │     │ Hospital B  │     │ Hospital C  │
│ Local DQN   │     │ Local DQN   │     │ Local DQN   │
└──────┬──────┘     └──────┬──────┘     └──────┬──────┘
       │                   │                   │
       └───────────────────┼───────────────────┘
                           │
                    ┌──────▼──────┐
                    │ Aggregator  │
                    │  (Server)   │
                    └─────────────┘

# Server
def federated_averaging(models, weights):
    """Return the weighted average of the models' state dicts (FedAvg).

    Args:
        models: Non-empty sequence of objects exposing ``state_dict()``.
            The keys of the first model's state dict define the result; every
            other model must provide the same keys.
        weights: One numeric aggregation weight per model, in the same order
            as ``models``.

    Returns:
        A dict mapping each key to ``sum(w_i * value_i) / sum(w)``.

    Raises:
        ValueError: If ``models`` is empty, if ``models`` and ``weights``
            differ in length, or if the weights sum to zero.
    """
    if not models:
        raise ValueError("federated_averaging requires at least one model")
    if len(models) != len(weights):
        # zip() would silently drop the surplus models or weights.
        raise ValueError(
            f"got {len(models)} models but {len(weights)} weights"
        )
    total = sum(weights)
    if total == 0:
        raise ValueError("aggregation weights must not sum to zero")

    # Materialise each state dict once instead of once per key per model.
    states = [model.state_dict() for model in models]
    return {
        key: sum(w * state[key] for state, w in zip(states, weights)) / total
        for key in states[0]
    }
# Round
# Each federated round: the selected clients train locally, then the server
# replaces the global weights with the weighted average of their models.
# NOTE(review): num_rounds, select_clients, local_epochs and global_model are
# not defined in the visible source; they must be provided by the caller.
for round_idx in range(num_rounds):  # not `round`: that shadows the builtin
    clients = select_clients()
    if not clients:
        # Nothing to aggregate; keep the current global model for this round.
        continue
    models, weights = [], []
    for client in clients:
        model, weight = client.train(local_epochs)
        models.append(model)
        weights.append(weight)
    global_model.load_state_dict(federated_averaging(models, weights))

import torch.nn as nn
import numpy as np
import torch


class DQN(nn.Module):
    """Fully connected Q-network producing ``action_dim`` outputs per state.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Size of the output layer (one value per action).
        hidden_dim: Width of the two hidden layers; defaults to 256.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.net(x)


def train_dqn(agent, replay_buffer, target_net):
    """Train ``agent`` with DQN updates against a periodically synced target.

    For each of ``num_steps`` outer iterations the environment is reset and
    stepped until ``done``. After every environment step the transition is
    pushed to ``replay_buffer``, a batch is sampled, and one optimizer step
    is taken on the MSE between Q(s, a) and
    ``reward + gamma * max_a' Q_target(s', a') * (1 - done)``.

    Args:
        agent: Online Q-network; called on ``batch.state`` and expected to
            expose ``select_action(state, epsilon)`` and ``state_dict()``.
        replay_buffer: Object exposing ``push(...)`` and ``sample(batch_size)``.
        target_net: Network used for the bootstrap target; receives a copy of
            the agent's weights every ``target_update`` outer iterations.

    NOTE(review): ``env``, ``optimizer``, ``num_steps``, ``epsilon``,
    ``batch_size``, ``gamma`` and ``target_update`` are read from the
    enclosing scope and are not defined in the visible source.
    NOTE(review): assumes the sampled batch fields are tensors with a leading
    batch dimension and that ``batch.action`` is (batch, 1) -- TODO confirm
    against the replay buffer implementation.
    """
    loss_fn = nn.MSELoss()  # built once rather than on every update

    for episode in range(num_steps):
        state = env.reset()
        done = False
        while not done:
            # Epsilon-greedy action
            action = agent.select_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)

            # Store transition
            replay_buffer.push(state, action, reward, next_state, done)

            # Sample batch
            batch = replay_buffer.sample(batch_size)

            # Q(s, a) of the actions taken, flattened to (batch,). Without the
            # flatten a (batch, 1) prediction broadcasts against the (batch,)
            # target inside MSELoss and yields a (batch, batch) error matrix.
            q_taken = agent(batch.state).gather(1, batch.action).reshape(-1)

            # The target is a constant for this update: no gradient should
            # flow into target_net.
            with torch.no_grad():
                next_q_max = target_net(batch.next_state).max(1)[0]
                # Cast so that bool `done` flags can be subtracted from 1.
                not_done = 1 - batch.done.reshape(-1).to(next_q_max.dtype)
                target = batch.reward.reshape(-1) + gamma * next_q_max * not_done

            loss = loss_fn(q_taken, target)

            # Update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state

        # Update target network
        if episode % target_update == 0:
            target_net.load_state_dict(agent.state_dict())


class MLFQScheduler:
    """Multi-level feedback queue of patients; queue 0 is the highest priority.

    Args:
        num_queues: Number of priority levels.
        priority_boost: Wait-time threshold above which a patient is promoted
            by ``boost_priorities``; defaults to 10.

    NOTE(review): ``get_next_patient`` relies on ``self.get_queue_state()``
    and on a ``dqn_agent`` from the enclosing scope; neither is defined in the
    visible source.
    """

    def __init__(self, num_queues=3, priority_boost=10):
        self.queues = [[] for _ in range(num_queues)]
        self.priority_boost = priority_boost

    def add_patient(self, patient, priority):
        """Append ``patient`` to the queue for ``priority``, clamped to range."""
        # Clamp on both sides: a negative priority would otherwise wrap
        # around to the lowest-priority queue via negative indexing.
        queue_idx = max(0, min(priority, len(self.queues) - 1))
        self.queues[queue_idx].append(patient)

    def get_next_patient(self):
        """Pop the next patient from the queue chosen by the DQN agent.

        Returns:
            The first patient of the selected queue, or ``None`` when that
            queue is empty or the selected action is not a queue index.
        """
        # DQN selects which queue to serve
        queue_state = self.get_queue_state()
        action = dqn_agent.select_action(queue_state)

        # Boost priority of waiting patients
        self.boost_priorities()

        # The action space may be larger than the number of queues; such
        # actions serve no queue instead of raising IndexError.
        if not 0 <= action < len(self.queues):
            return None
        queue = self.queues[action]
        return queue.pop(0) if queue else None

    def boost_priorities(self):
        """Promote patients waiting longer than ``priority_boost`` one level.

        Queues are visited from highest to lowest priority so that a patient
        promoted in this call is not examined again and moved a second time.
        """
        for i in range(1, len(self.queues)):
            staying, promoted = [], []
            # Partition instead of removing while iterating, which would skip
            # the element following every removed patient.
            for patient in self.queues[i]:
                if patient.wait_time > self.priority_boost:
                    promoted.append(patient)
                else:
                    staying.append(patient)
            self.queues[i - 1].extend(promoted)
            self.queues[i][:] = staying


def add_dp_noise(gradients, epsilon, delta, sensitivity):
    """Add Gaussian noise for (ε,δ)-differential privacy.

    The noise scale is ``sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon``.

    Args:
        gradients: Tensor to perturb; the noise has the same shape and dtype.
        epsilon: Privacy budget ε; must be positive.
        delta: Failure probability δ; must lie strictly between 0 and 1.
        sensitivity: Sensitivity bound of ``gradients``; must be non-negative.

    Returns:
        ``gradients`` plus the sampled noise, as a new tensor.

    Raises:
        ValueError: If ``epsilon``, ``delta`` or ``sensitivity`` is out of range.

    NOTE(review): this calibration is the classical Gaussian-mechanism bound,
    which is stated for ε < 1 -- confirm it fits the intended budget.
    """
    if epsilon <= 0:
        raise ValueError("epsilon must be positive")
    if not 0 < delta < 1:
        raise ValueError("delta must lie strictly between 0 and 1")
    if sensitivity < 0:
        raise ValueError("sensitivity must be non-negative")

    sigma = float(sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon)
    noise = torch.randn_like(gradients) * sigma
    return gradients + noise


# State layout. NOTE(review): the right-hand names (queues,
# average_acuity_per_queue, beds, ...) are not defined in the visible source.
state = {
    'queue_lengths': [len(q) for q in queues],  # Shape: (num_queues,)
    'patient_acuity': average_acuity_per_queue,  # Shape: (num_queues,)
    'resource_availability': [beds, staff, equipment],
    'time_features': [hour_of_day, day_of_week],
    'predicted_arrivals': next_hour_forecast,
}

# Discrete actions, keyed by action index.
actions = {
    0: 'Schedule from high-priority queue',
    1: 'Schedule from medium-priority queue',
    2: 'Schedule from low-priority queue',
    3: 'Allocate additional resource',
    4: 'Request transfer from other hospital',
}


def calculate_reward(state, action, next_state):
    """Accumulate the scheduling reward from wait, balance and load terms.

    NOTE(review): the parameters are not read; ``all_patients``,
    ``variance``, ``queue_lengths``, ``completed_high_acuity``,
    ``resource_utilization``, ``threshold`` and ``overutilization_penalty``
    come from the enclosing scope and are not defined in the visible source.
    """
    reward = 0

    # Minimize wait time (weighted by acuity)
    reward -= sum(
        patient.wait_time * patient.acuity
        for patient in all_patients
    )

    # Penalize queue imbalance
    reward -= variance(queue_lengths) * 10

    # Reward completing high-acuity cases
    reward += completed_high_acuity * 50

    # Penalize resource overutilization
    if resource_utilization > threshold:
        reward -= overutilization_penalty
    return reward

| Metric | Description |
|---|---|
| Privacy Budget (ε) | Differential privacy guarantee |
| Model Accuracy | Comparison to centralized training |
| Communication Rounds | Convergence speed |
| Patient Wait Time | Scheduling effectiveness |
| Resource Utilization | System efficiency |