diff --git a/dspy/predict/__init__.py b/dspy/predict/__init__.py index 1d4ddffced..8333792b48 100644 --- a/dspy/predict/__init__.py +++ b/dspy/predict/__init__.py @@ -1,4 +1,5 @@ from dspy.predict.aggregation import majority +from dspy.predict.helpers import majority_k from dspy.predict.best_of_n import BestOfN from dspy.predict.chain_of_thought import ChainOfThought from dspy.predict.code_act import CodeAct @@ -12,6 +13,7 @@ __all__ = [ "majority", + "majority_k", "BestOfN", "ChainOfThought", "CodeAct", diff --git a/dspy/predict/helpers.py b/dspy/predict/helpers.py new file mode 100644 index 0000000000..49eb263c17 --- /dev/null +++ b/dspy/predict/helpers.py @@ -0,0 +1,30 @@ +from typing import Any, Callable, Union +from dspy.predict.aggregation import majority +from dspy.primitives.prediction import Prediction, Completions + +def majority_k( + predict_fn: Union[Callable[..., Any], Prediction, Completions, list], + k: int = 5, + **majority_kwargs +) -> Any: + """ + Minimal wrapper running predict_fn k times and returning majority(). + + Args: + predict_fn: A callable that takes keyword arguments and returns a value, + or an existing Prediction/Completions/list to pass to majority(). + k: Number of times to run the predictor (only used if predict_fn is callable). + **majority_kwargs: Additional arguments to pass to majority(). + + Returns: + If predict_fn is callable: A callable that runs predict_fn k times and returns the majority result. + Otherwise: The result of majority(predict_fn, **majority_kwargs). + """ + if not callable(predict_fn): + return majority(predict_fn, **majority_kwargs) + + def wrapped(**inputs: Any) -> Any: + preds = [predict_fn(**inputs) for _ in range(k)] + return majority(preds, **majority_kwargs) + + return wrapped diff --git a/tests/predict/test_majority_k.py b/tests/predict/test_majority_k.py new file mode 100644 index 0000000000..3e7915c30a --- /dev/null +++ b/tests/predict/test_majority_k.py @@ -0,0 +1,53 @@ +import pytest +from unittest.mock import MagicMock, call +from dspy.predict.helpers import majority_k +from dspy.primitives.prediction import Prediction, Completions + +def test_majority_k_with_callable(monkeypatch): + # Test with a callable predictor + mock_predictor = MagicMock(return_value={"answer": "42"}) + mock_majority = MagicMock(return_value="mock_result") + monkeypatch.setattr("dspy.predict.helpers.majority", mock_majority) + + wrapped = majority_k(mock_predictor, k=3) + + # Call the wrapped function + result = wrapped(question="test", other_param=123) + + # Verify the predictor was called 3 times with the same args + assert mock_predictor.call_count == 3 + mock_predictor.assert_has_calls([call(question="test", other_param=123)] * 3) + + # Verify majority was called once with a list of k predictions + mock_majority.assert_called_once() + predictions = mock_majority.call_args[0][0] + assert len(predictions) == 3 + assert all(p == {"answer": "42"} for p in predictions) + assert result == "mock_result" + +def test_majority_k_with_existing_completions(monkeypatch): + # Test with existing completions (non-callable) + completions = [{"answer": "2"}, {"answer": "2"}, {"answer": "3"}] + mock_majority = MagicMock(return_value="mock_result") + monkeypatch.setattr("dspy.predict.helpers.majority", mock_majority) + + result = majority_k(completions, field="answer") + + # Should directly call majority() with the completions + mock_majority.assert_called_once_with(completions, field="answer") + assert result == "mock_result" + +def test_majority_k_with_kwargs(monkeypatch): + # Test that kwargs are passed to majority() + mock_majority = MagicMock(return_value="mock_result") + monkeypatch.setattr("dspy.predict.helpers.majority", mock_majority) + + predictor = lambda x: {"answer": x} + wrapped = majority_k(predictor, k=2, field="answer", normalize=lambda x: x) + + result = wrapped(x="test") + kwargs = mock_majority.call_args[1] + + assert kwargs["field"] == "answer" + assert callable(kwargs["normalize"]) + assert result == "mock_result"