Spaces:
Sleeping
Sleeping
"""aggregator.py module."""
import logging
from collections import defaultdict
from typing import List, Dict

import numpy as np
import tensorflow as tf
class FederatedAggregator:
    """Aggregates client model updates and metrics for federated learning.

    Supports sample-size-weighted aggregation (FedAvg) or uniform averaging,
    controlled by the ``weighted`` flag in the ``aggregation`` config section.
    """

    def __init__(self, config: Dict):
        """Initialize the aggregator from a configuration dict.

        Args:
            config: Must contain an 'aggregation' section, either at the top
                level or nested under 'server'.

        Raises:
            KeyError: If no 'aggregation' section can be found.
        """
        logger = logging.getLogger(__name__)
        logger.debug("Initializing FederatedAggregator with config: %s", config)
        # Defensive: the aggregation section may live at the top level or be
        # nested under 'server', depending on how the config was assembled.
        if 'aggregation' in config:
            agg_config = config['aggregation']
        elif 'server' in config and 'aggregation' in config['server']:
            agg_config = config['server']['aggregation']
        else:
            logger.error(
                "No 'aggregation' key found in config passed to FederatedAggregator: %s",
                config,
            )
            raise KeyError("'aggregation' config section is required for FederatedAggregator")
        # When True, clients contribute proportionally to their sample counts.
        self.weighted = agg_config.get('weighted', True)
        logger.info("FederatedAggregator initialized. Weighted: %s", self.weighted)

    def federated_averaging(self, updates: List[Dict]) -> List:
        """Perform federated averaging (FedAvg) on model weights.

        Args:
            updates: One dict per client with keys 'weights' (list of
                per-layer arrays), 'size' (sample count) and, optionally,
                'client_id' (used only for logging).

        Returns:
            A list of aggregated per-layer numpy arrays, or None when
            ``updates`` is empty.
        """
        logger = logging.getLogger(__name__)
        logger.info("Performing federated averaging on %d client updates", len(updates))
        if not updates:
            logger.warning("No updates provided for federated averaging")
            return None
        total_samples = sum(update['size'] for update in updates)
        logger.debug("Total samples across clients: %s", total_samples)
        # Bug fix: if every client reports size 0, the weighted path would
        # divide by zero — fall back to a uniform average in that case.
        use_weighted = self.weighted and total_samples > 0
        # Accumulators shaped like the first client's layer arrays.
        aggregated_weights = [np.zeros_like(w) for w in updates[0]['weights']]
        for update in updates:
            client_size = update['size']
            weight_factor = (client_size / total_samples if use_weighted
                             else 1.0 / len(updates))
            # .get() so a missing 'client_id' can't crash a debug log line.
            logger.debug("Client %s: size=%s, weight_factor=%s",
                         update.get('client_id'), client_size, weight_factor)
            # Add this client's weighted contribution, layer by layer.
            for i, client_w in enumerate(update['weights']):
                aggregated_weights[i] += np.asarray(client_w) * weight_factor
        logger.info("Federated averaging completed successfully")
        return aggregated_weights

    def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
        """Aggregate per-client evaluation metrics into a single dict.

        Weighted mode averages by sample count; unweighted mode is a plain
        mean over clients.

        Args:
            client_metrics: One dict per client with keys 'num_samples' and
                'metrics' (a mapping of metric name to value).

        Returns:
            Dict mapping each metric name to its aggregated value; empty
            dict when no client metrics are supplied.
        """
        logger = logging.getLogger(__name__)
        logger.debug("Computing metrics for %d clients", len(client_metrics))
        if not client_metrics:
            logger.warning("No client metrics provided to compute_metrics.")
            return {}
        aggregated_metrics = defaultdict(float)
        total_samples = sum(m['num_samples'] for m in client_metrics)
        logger.debug("Total samples across clients: %s", total_samples)
        for metrics in client_metrics:
            # Bug fix: the unweighted path previously used weight=1.0, which
            # SUMS metrics across clients instead of averaging them (and is
            # inconsistent with federated_averaging's 1/N unweighted factor).
            # Also guards total_samples == 0 against division by zero.
            if self.weighted and total_samples > 0:
                weight = metrics['num_samples'] / total_samples
            else:
                weight = 1.0 / len(client_metrics)
            logger.debug("Client metrics: %s, weight: %s", metrics, weight)
            for metric_name, value in metrics['metrics'].items():
                aggregated_metrics[metric_name] += value * weight
        logger.info("Aggregated metrics: %s", dict(aggregated_metrics))
        return dict(aggregated_metrics)

    def check_convergence(self,
                          old_weights: List,
                          new_weights: List,
                          threshold: float = 1e-5) -> bool:
        """Return True when every layer's mean absolute change is below threshold.

        Args:
            old_weights: Per-layer arrays from the previous round (or None).
            new_weights: Per-layer arrays from the current round (or None).
            threshold: Convergence tolerance on the mean absolute difference.

        Returns:
            True if all per-layer mean absolute differences are below
            ``threshold``; False otherwise, or if either input is None.
        """
        logger = logging.getLogger(__name__)
        logger.debug("Checking convergence...")
        if old_weights is None or new_weights is None:
            logger.warning("Old or new weights are None in check_convergence.")
            return False
        # np.asarray lets callers pass plain lists as well as ndarrays.
        weight_differences = [
            np.mean(np.abs(np.asarray(old) - np.asarray(new)))
            for old, new in zip(old_weights, new_weights)
        ]
        logger.debug("Weight differences: %s", weight_differences)
        converged = all(diff < threshold for diff in weight_differences)
        logger.info("Convergence status: %s", converged)
        return converged