Spaces:
Sleeping
Sleeping
"""coordinator.py module.""" | |
import tensorflow as tf | |
from typing import List, Dict, Any, Optional | |
import numpy as np | |
from collections import defaultdict | |
import logging | |
import time | |
import threading | |
from .aggregator import FederatedAggregator | |
class FederatedCoordinator: | |
def __init__(self, config: Dict): | |
"""Initialize the federated learning coordinator.""" | |
logger = logging.getLogger(__name__) | |
logger.debug(f"Initializing FederatedCoordinator with config: {config}") | |
self.config = config | |
self.clients = {} | |
self.client_updates = {} # Store updates for current round | |
self.global_model_weights = None | |
self.current_round = 0 | |
self.training_active = False | |
# Extract federated learning parameters | |
self.min_clients = config.get('federated', {}).get('min_clients', 2) | |
self.rounds = config.get('federated', {}).get('rounds', 10) | |
# Debug: log config structure | |
logger.debug(f"Coordinator received config: {config}") | |
# Robustly extract aggregation config | |
agg_config = None | |
if 'aggregation' in config: | |
agg_config = config | |
elif 'server' in config and 'aggregation' in config['server']: | |
agg_config = config['server'] | |
else: | |
logger.error(f"No 'aggregation' key found in config for FederatedAggregator: {config}") | |
raise ValueError("'aggregation' config section is required for FederatedAggregator") | |
logger.debug(f"Passing aggregation config to FederatedAggregator: {agg_config}") | |
try: | |
self.aggregator = FederatedAggregator(agg_config) | |
except Exception as e: | |
logger.error(f"Error initializing FederatedAggregator: {e}") | |
raise | |
# Initialize global model weights with random values | |
self._initialize_global_model() | |
self.lock = threading.Lock() # Thread safety for concurrent API calls | |
logger.info("FederatedCoordinator initialized.") | |
def _initialize_global_model(self): | |
"""Initialize global model weights with random values.""" | |
logger = logging.getLogger(__name__) | |
try: | |
# Build a simple model to get initial weights | |
input_dim = self.config.get('model', {}).get('input_dim', 32) | |
hidden_layers = self.config.get('model', {}).get('hidden_layers', [128, 64]) | |
model = tf.keras.Sequential([ | |
tf.keras.layers.Input(shape=(input_dim,)), | |
tf.keras.layers.Dense(hidden_layers[0], activation='relu'), | |
tf.keras.layers.Dense(hidden_layers[1], activation='relu'), | |
tf.keras.layers.Dense(1) | |
]) | |
model.compile(optimizer='adam', loss='mse') | |
self.global_model_weights = model.get_weights() | |
logger.info(f"Global model initialized with {len(self.global_model_weights)} weight layers") | |
except Exception as e: | |
logger.error(f"Error initializing global model: {e}") | |
# Fallback to simple random weights | |
self.global_model_weights = [ | |
np.random.randn(32, 128).astype(np.float32), | |
np.random.randn(128).astype(np.float32), | |
np.random.randn(128, 64).astype(np.float32), | |
np.random.randn(64).astype(np.float32), | |
np.random.randn(64, 1).astype(np.float32), | |
np.random.randn(1).astype(np.float32) | |
] | |
logger.info("Using fallback random weights for global model") | |
def register_client(self, client_id: str, client_info: Dict[str, Any] = None) -> bool: | |
"""Register a new client.""" | |
with self.lock: | |
if client_id in self.clients: | |
logging.getLogger(__name__).warning(f"Client {client_id} already registered") | |
return True | |
self.clients[client_id] = { | |
'info': client_info or {}, | |
'last_seen': time.time(), | |
'metrics': defaultdict(list) | |
} | |
logging.getLogger(__name__).info(f"Client {client_id} registered successfully") | |
return True | |
def get_client_config(self) -> Dict[str, Any]: | |
"""Get configuration to send to clients""" | |
return { | |
'model_config': self.config.get('model', {}), | |
'training_config': self.config.get('training', {}), | |
'current_round': self.current_round, | |
'total_rounds': self.rounds | |
} | |
def get_global_model(self) -> Optional[List]: | |
"""Get the current global model weights""" | |
with self.lock: | |
return self.global_model_weights | |
def receive_model_update(self, client_id: str, model_weights: List, metrics: Dict[str, Any]): | |
"""Receive a model update from a client""" | |
with self.lock: | |
if client_id not in self.clients: | |
raise ValueError(f"Client {client_id} not registered") | |
self.client_updates[client_id] = { | |
'weights': model_weights, | |
'metrics': metrics, | |
'timestamp': time.time() | |
} | |
self.clients[client_id]['last_seen'] = time.time() | |
logger = logging.getLogger(__name__) | |
logger.info(f"Received update from client {client_id}") | |
# Check if we have enough updates for aggregation | |
if len(self.client_updates) >= self.min_clients: | |
self._aggregate_models() | |
def _aggregate_models(self): | |
"""Aggregate models from all client updates""" | |
try: | |
logger = logging.getLogger(__name__) | |
logger.info(f"Aggregating models from {len(self.client_updates)} clients") | |
# Prepare updates for aggregation | |
updates = [] | |
for client_id, update in self.client_updates.items(): | |
client_size = update['metrics'].get('dataset_size', 100) # Default size | |
updates.append({ | |
'client_id': client_id, | |
'weights': update['weights'], | |
'size': client_size | |
}) | |
# Aggregate using FedAvg | |
self.global_model_weights = self.aggregator.federated_averaging(updates) | |
# Clear updates for next round | |
self.client_updates.clear() | |
self.current_round += 1 | |
logger.info(f"Model aggregation completed for round {self.current_round}") | |
except Exception as e: | |
logger = logging.getLogger(__name__) | |
logger.error(f"Error during model aggregation: {str(e)}") | |
def _count_active_clients(self) -> int: | |
"""Count active clients (seen in last 60 seconds)""" | |
current_time = time.time() | |
active_count = sum(1 for client in self.clients.values() | |
if current_time - client['last_seen'] < 60) | |
return active_count | |
def start(self): | |
"""Start the federated learning process with API server""" | |
logger = logging.getLogger(__name__) | |
# Print server startup information | |
logger.info("\n" + "=" * 60) | |
logger.info(f"{'Federated Learning Server Starting':^60}") | |
logger.info("=" * 60) | |
# Print configuration details | |
logger.info("\nServer Configuration:") | |
logger.info("-" * 30) | |
logger.info(f"Minimum clients required: {self.min_clients}") | |
logger.info(f"Total rounds planned: {self.rounds}") | |
active_clients_count = self._count_active_clients() | |
logger.info(f"Current active clients: {active_clients_count}") | |
logger.info("-" * 30 + "\n") | |
self.training_active = True | |
# Import and start API server | |
try: | |
from ..api.server import FederatedAPI | |
api_config = self.config.get('api', {}) | |
host = api_config.get('host', '0.0.0.0') | |
port = api_config.get('port', 8080) | |
api_server = FederatedAPI(self, host, port) | |
api_thread = api_server.run_threaded() | |
logger.info(f"API server started on {host}:{port}") | |
# Keep server running | |
try: | |
while self.training_active and self.current_round < self.rounds: | |
time.sleep(1) # Keep main thread alive | |
# Log progress periodically | |
active_clients_count = self._count_active_clients() | |
if active_clients_count > 0: | |
logger.debug(f"Round {self.current_round}/{self.rounds}, " | |
f"Active Clients: {active_clients_count}, " | |
f"Updates: {len(self.client_updates)}") | |
logger.info("Federated learning completed successfully") | |
except KeyboardInterrupt: | |
logger.info("Server shutdown requested") | |
self.training_active = False | |
except ImportError as e: | |
logger.error(f"Failed to start API server: {str(e)}") | |
# Fallback to original behavior | |
# ...existing code... | |