Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions framework/py/flwr/serverapp/strategy/multikrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,16 @@ def compute_distances(records: list[ArrayRecord]) -> NDArray:
A 2D array representing the distance matrix of squared L2 distances
between input ArrayRecords
"""
# Formula: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
# Flatten records and stack them into a matrix
flat_w = np.stack(
[np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel() for rec in records],
axis=0,
) # shape: (n, d) with n number of records and d the dimension of model
# Flatten each record to a 1D array efficiently
flat_model_arrays = []
for rec in records:
arrs = rec.to_numpy_ndarrays()
if len(arrs) == 1:
flat = arrs[0].ravel()
else:
flat = np.concatenate(arrs, axis=None).ravel()
flat_model_arrays.append(flat)
flat_w = np.stack(flat_model_arrays, axis=0)

# Compute squared norms of each vector
norms: NDArray = np.square(flat_w).sum(axis=1) # shape (n,)
Expand Down Expand Up @@ -220,28 +224,19 @@ def select_multikrum(
If `num_nodes_to_select` is set to 1, Multi-Krum reduces to classical Krum
and only a single RecordDict is selected.
"""
# Construct list of ArrayRecord objects from replies
record_key = list(contents[0].array_records.keys())[0]
# Recall aggregate_train first ensures replies only contain one ArrayRecord
array_records = [cast(ArrayRecord, reply[record_key]) for reply in contents]
distance_matrix = compute_distances(array_records)

# For each node, take the n-f-2 closest parameters vectors
num_closest = max(1, len(array_records) - num_malicious_nodes - 2)
closest_indices = []
for distance in distance_matrix:
closest_indices.append(
np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
)
# Get closest indices for each node efficiently using numpy
sorted_indices = np.argsort(distance_matrix, axis=1)
# Remove self (index 0) and get the n-f-2 closest (index 1 to num_closest)
closest_indices = sorted_indices[:, 1 : num_closest + 1]

# Compute the score for each node, that is the sum of the distances
# of the n-f-2 closest parameters vectors
scores = [
np.sum(distance_matrix[i, closest_indices[i]])
for i in range(len(distance_matrix))
]
# Efficiently compute scores as sum of closest distances (vectorized)
scores = np.take_along_axis(distance_matrix, closest_indices, axis=1).sum(axis=1)

# Choose the num_nodes_to_select lowest-scoring nodes (MultiKrum)
# and return their updates
best_indices = np.argsort(scores)[:num_nodes_to_select]
return [contents[i] for i in best_indices]