diff --git a/framework/py/flwr/serverapp/strategy/multikrum.py b/framework/py/flwr/serverapp/strategy/multikrum.py index 2198c36086bf..cdb9d804b024 100644 --- a/framework/py/flwr/serverapp/strategy/multikrum.py +++ b/framework/py/flwr/serverapp/strategy/multikrum.py @@ -171,12 +171,16 @@ def compute_distances(records: list[ArrayRecord]) -> NDArray: A 2D array representing the distance matrix of squared L2 distances between input ArrayRecords """ - # Formula: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y - # Flatten records and stack them into a matrix - flat_w = np.stack( - [np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel() for rec in records], - axis=0, - ) # shape: (n, d) with n number of records and d the dimension of model + # Flatten each record to a 1D array efficiently + flat_model_arrays = [] + for rec in records: + arrs = rec.to_numpy_ndarrays() + if len(arrs) == 1: + flat = arrs[0].ravel() + else: + flat = np.concatenate(arrs, axis=None).ravel() + flat_model_arrays.append(flat) + flat_w = np.stack(flat_model_arrays, axis=0) # Compute squared norms of each vector norms: NDArray = np.square(flat_w).sum(axis=1) # shape (n,) @@ -220,28 +224,19 @@ def select_multikrum( If `num_nodes_to_select` is set to 1, Multi-Krum reduces to classical Krum and only a single RecordDict is selected. """ - # Construct list of ArrayRecord objects from replies record_key = list(contents[0].array_records.keys())[0] - # Recall aggregate_train first ensures replies only contain one ArrayRecord array_records = [cast(ArrayRecord, reply[record_key]) for reply in contents] distance_matrix = compute_distances(array_records) - # For each node, take the n-f-2 closest parameters vectors num_closest = max(1, len(array_records) - num_malicious_nodes - 2) - closest_indices = [] - for distance in distance_matrix: - closest_indices.append( - np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203 - ) + # Get closest indices for each node efficiently using numpy + sorted_indices = np.argsort(distance_matrix, axis=1) + # Remove self (index 0) and get the n-f-2 closest (index 1 to num_closest) + closest_indices = sorted_indices[:, 1 : num_closest + 1] - # Compute the score for each node, that is the sum of the distances - # of the n-f-2 closest parameters vectors - scores = [ - np.sum(distance_matrix[i, closest_indices[i]]) - for i in range(len(distance_matrix)) - ] + # Efficiently compute scores as sum of closest distances (vectorized) + scores = np.take_along_axis(distance_matrix, closest_indices, axis=1).sum(axis=1) # Choose the num_nodes_to_select lowest-scoring nodes (MultiKrum) - # and return their updates best_indices = np.argsort(scores)[:num_nodes_to_select] return [contents[i] for i in best_indices]