Conversation

@codeflash-ai codeflash-ai bot commented Oct 27, 2025

📄 13% (0.13x) speedup for aggregate_metricrecords in framework/py/flwr/serverapp/strategy/strategy_utils.py

⏱️ Runtime: 17.3 milliseconds → 15.3 milliseconds (best of 139 runs)

📝 Explanation and details

The optimized code achieves a 12% speedup through two key changes in the aggregate_metricrecords function:

1. List Comprehension for Weight Extraction

  • Replaced the explicit loop with a list comprehension to extract weights: weights: list[float] = [cast(float, next(iter(record.metric_records.values()))[weighting_metric_name]) for record in records]
  • This eliminates the overhead of multiple append() calls and reduces the number of intermediate variable assignments (see the sketch after this list)

2. In-Place List Updates

  • For list-valued metrics, replaced the expensive list comprehension [curr + val * weight for curr, val in zip(current_list, value)] with an in-place update loop: for i, val in enumerate(value): curr_list[i] += val * weight
  • This avoids creating new list objects for each aggregation step, which is particularly beneficial when dealing with large lists or many records
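
A minimal, self-contained sketch of both changes side by side. The MetricRecord and RecordDict classes here are simplified stand-ins for the real flwr record types, and the sample records are made up for illustration:

from typing import cast

# Simplified stand-ins for the real flwr record types (illustration only)
class MetricRecord(dict):
    pass

class RecordDict:
    def __init__(self, metric_records):
        self.metric_records = metric_records

records = [
    RecordDict({"a": MetricRecord({"weight": 2.0, "vec": [1.0, 2.0]})}),
    RecordDict({"a": MetricRecord({"weight": 1.0, "vec": [3.0, 4.0]})}),
]
weighting_metric_name = "weight"

# Change 1, before: explicit loop with repeated append() calls
weights: list[float] = []
for record in records:
    first_record = next(iter(record.metric_records.values()))
    weights.append(cast(float, first_record[weighting_metric_name]))

# Change 1, after: a single list comprehension
weights = [
    cast(float, next(iter(record.metric_records.values()))[weighting_metric_name])
    for record in records
]

# Change 2, before: each aggregation step allocates a brand-new list
curr_list, value, weight = [0.0, 0.0], [1.0, 2.0], 0.5
curr_list = [curr + val * weight for curr, val in zip(curr_list, value)]

# Change 2, after: the running aggregate is mutated in place
curr_list = [0.0, 0.0]
for i, val in enumerate(value):
    curr_list[i] += val * weight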

The line profiler shows the most significant improvement in the list aggregation section: the original code spent 18.4% of total time creating new lists via comprehension, while the optimized version spends only 4.7% on in-place updates. The optimization is most effective for test cases with large numbers of records containing list-valued metrics, as evidenced by the performance improvements in the large-scale tests with vector data.

These changes maintain the same algorithmic complexity while reducing memory allocations and function call overhead, resulting in the observed 12% performance gain.
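
Putting both changes together, here is a compact sketch of the full weighted-aggregation loop, written against the same simplified stubs as above. It illustrates the technique only; the exact strategy_utils.py implementation may differ, for example in how mismatched list lengths are handled:

from typing import cast

def aggregate_metricrecords_sketch(records, weighting_metric_name):
    # Change 1: extract per-record weights with a list comprehension
    weights = [
        cast(float, next(iter(r.metric_records.values()))[weighting_metric_name])
        for r in records
    ]
    total_weight = sum(weights)

    result = {}
    for record, weight in zip(records, weights):
        factor = weight / total_weight  # ZeroDivisionError if weights sum to 0
        for metric_record in record.metric_records.values():
            for key, value in metric_record.items():
                if key == weighting_metric_name:
                    continue  # the weighting metric is excluded from the output
                if isinstance(value, list):
                    if key in result:
                        # Change 2: in-place update instead of rebuilding the list
                        curr_list = result[key]
                        for i, val in enumerate(value):
                            curr_list[i] += val * factor
                    else:
                        result[key] = [val * factor for val in value]
                else:
                    result[key] = result.get(key, 0.0) + value * factor
    return result

With the sample records from the sketch above, aggregate_metricrecords_sketch(records, "weight") yields {"vec": [5/3, 8/3]}, the 2:1 weighted average of the two vectors.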

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 38 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import pytest
from serverapp.strategy.strategy_utils import aggregate_metricrecords

# --- Begin: Minimal stubs for dependencies ---

# Minimal MetricRecord: just a dict[str, float|list[float]]
class MetricRecord(dict):
    pass

# Minimal RecordDict: has .metric_records attribute (dict[str, MetricRecord])
class RecordDict:
    def __init__(self, metric_records):
        self.metric_records = metric_records

# ------------------- UNIT TESTS -------------------

# 1. Basic Test Cases

def test_single_record_single_metric():
    # One record, one metric (besides weight)
    rec = RecordDict({"a": MetricRecord({"weight": 2.0, "accuracy": 0.8})})
    codeflash_output = aggregate_metricrecords([rec], "weight"); result = codeflash_output

def test_two_records_scalar_average():
    # Two records, scalar metric
    rec1 = RecordDict({"a": MetricRecord({"weight": 2.0, "accuracy": 0.5})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 1.0, "accuracy": 1.0})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_two_records_list_metric():
    # Two records, list-valued metric
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "loss": [1.0, 2.0]})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 3.0, "loss": [3.0, 4.0]})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_multiple_metrics():
    # Multiple metrics per record
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": 0.2, "loss": 3.0})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 3.0, "acc": 0.8, "loss": 1.0})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_weighting_metric_is_ignored_in_output():
    # The weighting metric should not appear in the result
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": 0.1})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": 0.3})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

# 2. Edge Test Cases

def test_zero_weights_raises():
    # All weights are zero: should raise ZeroDivisionError
    rec1 = RecordDict({"a": MetricRecord({"weight": 0.0, "acc": 1.0})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 0.0, "acc": 2.0})})
    with pytest.raises(ZeroDivisionError):
        aggregate_metricrecords([rec1, rec2], "weight")

def test_negative_weights():
    # Negative weights are allowed, but result may be outside [min, max]
    rec1 = RecordDict({"a": MetricRecord({"weight": -1.0, "acc": 2.0})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 2.0, "acc": 4.0})})
    # total_weight = -1 + 2 = 1, so the factors are [-1, 2]:
    # result["acc"] = -1 * 2.0 + 2 * 4.0 = 6.0
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output


def test_missing_weighting_metric():
    # One record missing the weighting metric
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": 0.5})})
    rec2 = RecordDict({"a": MetricRecord({"acc": 0.7})})
    with pytest.raises(KeyError):
        aggregate_metricrecords([rec1, rec2], "weight")

def test_mismatched_metric_keys():
    # Some records missing a metric key: only aggregate where present
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": 0.5})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 1.0, "loss": 0.7})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_list_metric_length_mismatch():
    # List metrics of different lengths: zip silently truncates to the
    # shorter list, so result["loss"] has length 1 (no exception is raised)
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "loss": [1.0, 2.0]})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 1.0, "loss": [3.0]})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_non_float_weights():
    # Weighting metric is int, not float
    rec1 = RecordDict({"a": MetricRecord({"weight": 1, "acc": 0.5})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 3, "acc": 1.0})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_non_numeric_metric_raises():
    # Metric values that are not numeric should raise TypeError
    rec1 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": "bad"})})
    rec2 = RecordDict({"a": MetricRecord({"weight": 1.0, "acc": 1.0})})
    with pytest.raises(TypeError):
        aggregate_metricrecords([rec1, rec2], "weight")

def test_multiple_metricrecords_per_record():
    # Multiple MetricRecords in one RecordDict: all are aggregated
    rec1 = RecordDict({
        "a": MetricRecord({"weight": 1.0, "acc": 0.2}),
        "b": MetricRecord({"weight": 1.0, "acc": 0.4}),
    })
    rec2 = RecordDict({
        "a": MetricRecord({"weight": 2.0, "acc": 0.8}),
        "b": MetricRecord({"weight": 2.0, "acc": 0.6}),
    })
    # The implementation only considers the first MetricRecord for weighting,
    # but aggregates all metrics in all MetricRecords.
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

# 3. Large Scale Test Cases

def test_large_number_of_records_scalar():
    # 1000 records, scalar metric
    n = 1000
    records = []
    for i in range(n):
        # weights: 1..1000, metric: i/1000
        records.append(
            RecordDict({"a": MetricRecord({"weight": float(i + 1), "acc": float(i) / n})})
        )
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
    # Weighted average: sum(w_i * x_i) / sum(w_i)
    total_weight = sum(i + 1 for i in range(n))
    expected = sum((i + 1) * (i / n) for i in range(n)) / total_weight
    assert result["acc"] == pytest.approx(expected)

def test_large_number_of_records_list():
    # 500 records, list metric of length 3
    n = 500
    records = []
    for i in range(n):
        records.append(
            RecordDict({"a": MetricRecord({
                "weight": 1.0,
                "vec": [i, i+1, i+2]
            })})
        )
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
    # Each coordinate: average of i, i+1, i+2 over 0..499
    expected = [
        sum(i for i in range(n)) / n,
        sum(i + 1 for i in range(n)) / n,
        sum(i + 2 for i in range(n)) / n,
    ]
    assert result["vec"] == pytest.approx(expected)

def test_large_number_of_metrics_per_record():
    # 100 records, each with 10 metrics
    n = 100
    m = 10
    records = []
    for i in range(n):
        metrics = {"weight": 1.0}
        for j in range(m):
            metrics[f"metric{j}"] = float(i + j)
        records.append(RecordDict({"a": MetricRecord(metrics)}))
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
    # Each metric: average of i+j over i=0..99
    for j in range(m):
        expected = sum(i + j for i in range(n)) / n
        assert result[f"metric{j}"] == pytest.approx(expected)

def test_large_scale_list_metric_length():
    # 100 records, each with a list metric of length 100
    n = 100
    l = 100
    records = []
    for i in range(n):
        metrics = {"weight": 1.0, "vec": [float(i)] * l}
        records.append(RecordDict({"a": MetricRecord(metrics)}))
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest
from serverapp.strategy.strategy_utils import aggregate_metricrecords

# --- Begin: Minimal stubs for MetricRecord and RecordDict ---

class MetricRecord(dict):
    """A dict subclass representing a metric record."""

class RecordDict:
    """A container for metric records, with a .metric_records attribute."""
    def __init__(self, metric_records):
        # metric_records: dict[str, MetricRecord]
        self.metric_records = metric_records

# --- Begin: Unit Tests ---

# 1. Basic Test Cases

def test_single_record_single_metric():
    # One record, one metric, weight is 1
    rec = RecordDict({"rec": MetricRecord({"weight": 1.0, "accuracy": 0.8})})
    codeflash_output = aggregate_metricrecords([rec], "weight"); result = codeflash_output

def test_two_records_simple_average():
    # Two records, same metric, weights 1 and 1
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "accuracy": 0.5})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "accuracy": 0.9})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_two_records_different_weights():
    # Two records, weights 2 and 1
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 2.0, "accuracy": 0.5})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "accuracy": 0.8})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_multiple_metrics():
    # Two records, two metrics, weights 1 and 3
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "accuracy": 0.2, "loss": 2.0})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 3.0, "accuracy": 0.8, "loss": 1.0})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_list_metric():
    # Two records, metric is a list
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "vector": [1.0, 2.0]})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "vector": [3.0, 4.0]})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_list_metric_different_weights():
    # Two records, metric is a list, weights 1 and 3
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "vector": [1.0, 2.0]})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 3.0, "vector": [5.0, 7.0]})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_multiple_metricrecords_per_record():
    # Each RecordDict has multiple MetricRecords, but only the first one is used for weighting
    rec1 = RecordDict({
        "rec1a": MetricRecord({"weight": 2.0, "accuracy": 0.4}),
        "rec1b": MetricRecord({"loss": 1.0})
    })
    rec2 = RecordDict({
        "rec2a": MetricRecord({"weight": 1.0, "accuracy": 0.7}),
        "rec2b": MetricRecord({"loss": 2.0})
    })
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

# 2. Edge Test Cases

def test_zero_weights():
    # All weights are zero -- should raise ZeroDivisionError
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 0.0, "accuracy": 0.7})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 0.0, "accuracy": 0.9})})
    with pytest.raises(ZeroDivisionError):
        aggregate_metricrecords([rec1, rec2], "weight")

def test_negative_weights():
    # Negative weights are allowed mathematically
    rec1 = RecordDict({"rec1": MetricRecord({"weight": -1.0, "accuracy": 0.2})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 2.0, "accuracy": 0.8})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_weighting_metric_missing():
    # If weighting_metric_name is missing in a record, should raise KeyError
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "accuracy": 0.5})})
    rec2 = RecordDict({"rec2": MetricRecord({"accuracy": 0.7})})
    with pytest.raises(KeyError):
        aggregate_metricrecords([rec1, rec2], "weight")


def test_empty_metric_records_dict():
    # RecordDict with empty metric_records dict, should raise StopIteration
    rec1 = RecordDict({})
    with pytest.raises(StopIteration):
        aggregate_metricrecords([rec1], "weight")

def test_metric_is_list_of_different_lengths():
    # If lists are of different lengths, zip will truncate to shortest
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "vector": [1.0, 2.0, 3.0]})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "vector": [4.0, 5.0]})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

def test_metric_is_non_numeric():
    # If metric value is not numeric, should raise TypeError on multiplication
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "accuracy": "high"})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "accuracy": 0.7})})
    with pytest.raises(TypeError):
        aggregate_metricrecords([rec1, rec2], "weight")

def test_metric_is_none():
    # If metric value is None, should raise TypeError on multiplication
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "accuracy": None})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "accuracy": 0.7})})
    with pytest.raises(TypeError):
        aggregate_metricrecords([rec1, rec2], "weight")

def test_weighting_metric_is_list():
    # If weighting metric is a list, should raise TypeError when it is used in weight arithmetic
    rec1 = RecordDict({"rec1": MetricRecord({"weight": [1.0], "accuracy": 0.5})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": [2.0], "accuracy": 0.7})})
    with pytest.raises(TypeError):
        aggregate_metricrecords([rec1, rec2], "weight")


def test_metric_key_is_same_as_weighting_metric():
    # If a metric key is the same as the weighting metric, it should be excluded
    rec1 = RecordDict({"rec1": MetricRecord({"weight": 1.0, "weight": 0.5, "accuracy": 0.9})})
    rec2 = RecordDict({"rec2": MetricRecord({"weight": 1.0, "weight": 0.7, "accuracy": 0.8})})
    codeflash_output = aggregate_metricrecords([rec1, rec2], "weight"); result = codeflash_output

# 3. Large Scale Test Cases

def test_large_number_of_records():
    # 1000 records, all with weight 1, accuracy from 0 to 0.999
    records = []
    for i in range(1000):
        rec = RecordDict({f"rec{i}": MetricRecord({"weight": 1.0, "accuracy": i / 999.0})})
        records.append(rec)
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output

def test_large_number_of_metrics():
    # 10 records, each with 100 metrics, all weights 1
    metrics = [f"m{i}" for i in range(100)]
    records = []
    for j in range(10):
        mr = {f"rec{j}": MetricRecord({"weight": 1.0, **{m: float(j) for m in metrics}})}
        records.append(RecordDict(mr))
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
    # Each metric should be the average of j=0..9, which is 4.5
    for m in metrics:
        assert result[m] == pytest.approx(4.5)

def test_large_list_metrics():
    # 10 records, each with a metric that's a list of length 100
    records = []
    for j in range(10):
        vector = [float(j) for _ in range(100)]
        mr = {f"rec{j}": MetricRecord({"weight": 1.0, "vector": vector})}
        records.append(RecordDict(mr))
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output

def test_large_number_of_records_and_metrics():
    # 100 records, each with 10 metrics, all weights 1
    metrics = [f"m{i}" for i in range(10)]
    records = []
    for j in range(100):
        mr = {f"rec{j}": MetricRecord({"weight": 1.0, **{m: float(j) for m in metrics}})}
        records.append(RecordDict(mr))
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
    # Each metric should be the average of j=0..99, which is 49.5
    for m in metrics:
        assert result[m] == pytest.approx(49.5)

def test_large_scale_mixed_list_and_scalar_metrics():
    # 100 records, each with a scalar and a list metric
    records = []
    for j in range(100):
        mr = {f"rec{j}": MetricRecord({"weight": 1.0, "accuracy": float(j), "vector": [float(j)]*10})}
        records.append(RecordDict(mr))
    codeflash_output = aggregate_metricrecords(records, "weight"); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from flwr.common.record.metricrecord import MetricRecord
from flwr.common.record.recorddict import RecordSet
from serverapp.strategy.strategy_utils import aggregate_metricrecords

def test_aggregate_metricrecords():
    aggregate_metricrecords([RecordSet(records={}, parameters_records=None, metrics_records={'': MetricRecord(metric_dict={}, keep_input=False)}, configs_records={})], '')

To edit these changes, run `git checkout codeflash/optimize-aggregate_metricrecords-mh9hz1xx` and push.

Codeflash

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 on October 27, 2025 at 18:55
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on October 27, 2025