Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dspy.predict.aggregation import majority
from dspy.predict.helpers import majority_k
from dspy.predict.best_of_n import BestOfN
from dspy.predict.chain_of_thought import ChainOfThought
from dspy.predict.code_act import CodeAct
Expand All @@ -12,6 +13,7 @@

__all__ = [
"majority",
"majority_k",
"BestOfN",
"ChainOfThought",
"CodeAct",
Expand Down
30 changes: 30 additions & 0 deletions dspy/predict/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, Callable, Union
from dspy.predict.aggregation import majority
from dspy.primitives.prediction import Prediction, Completions

def majority_k(
predict_fn: Union[Callable[..., Any], Prediction, Completions, list],
k: int = 5,
**majority_kwargs
) -> Any:
"""
Minimal wrapper running predict_fn k times and returning majority().

Args:
predict_fn: A callable that takes keyword arguments and returns a value,
or an existing Prediction/Completions/list to pass to majority().
k: Number of times to run the predictor (only used if predict_fn is callable).
**majority_kwargs: Additional arguments to pass to majority().

Returns:
If predict_fn is callable: A callable that runs predict_fn k times and returns the majority result.
Otherwise: The result of majority(predict_fn, **majority_kwargs).
"""
if not callable(predict_fn):
return majority(predict_fn, **majority_kwargs)

def wrapped(**inputs: Any) -> Any:
preds = [predict_fn(**inputs) for _ in range(k)]
return majority(preds, **majority_kwargs)

return wrapped
53 changes: 53 additions & 0 deletions tests/predict/test_majority_k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from unittest.mock import MagicMock, call
from dspy.predict.helpers import majority_k
from dspy.primitives.prediction import Prediction, Completions

def test_majority_k_with_callable(monkeypatch):
# Test with a callable predictor
mock_predictor = MagicMock(return_value={"answer": "42"})
mock_majority = MagicMock(return_value="mock_result")
monkeypatch.setattr("dspy.predict.helpers.majority", mock_majority)

wrapped = majority_k(mock_predictor, k=3)

# Call the wrapped function
result = wrapped(question="test", other_param=123)

# Verify the predictor was called 3 times with the same args
assert mock_predictor.call_count == 3
mock_predictor.assert_has_calls([call(question="test", other_param=123)] * 3)

# Verify majority was called once with a list of k predictions
mock_majority.assert_called_once()
predictions = mock_majority.call_args[0][0]
assert len(predictions) == 3
assert all(p == {"answer": "42"} for p in predictions)
assert result == "mock_result"

def test_majority_k_with_existing_completions(monkeypatch):
# Test with existing completions (non-callable)
completions = [{"answer": "2"}, {"answer": "2"}, {"answer": "3"}]
mock_majority = MagicMock(return_value="mock_result")
monkeypatch.setattr("dspy.predict.helpers.majority", mock_majority)

result = majority_k(completions, field="answer")

# Should directly call majority() with the completions
mock_majority.assert_called_once_with(completions, field="answer")
assert result == "mock_result"

def test_majority_k_with_kwargs(monkeypatch):
# Test that kwargs are passed to majority()
mock_majority = MagicMock(return_value="mock_result")
monkeypatch.setattr("dspy.predict.helpers.majority", mock_majority)

predictor = lambda x: {"answer": x}
wrapped = majority_k(predictor, k=2, field="answer", normalize=lambda x: x)

result = wrapped(x="test")
kwargs = mock_majority.call_args[1]

assert kwargs["field"] == "answer"
assert callable(kwargs["normalize"])
assert result == "mock_result"