Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion py/src/braintrust/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,10 @@ async def run_evaluator(
)

if experiment:
summary = experiment.summarize(summarize_scores=evaluator.summarize_scores)
summary = experiment.summarize(
summarize_scores=evaluator.summarize_scores,
comparison_experiment_id=evaluator.base_experiment_id,
)
else:
summary = build_local_summary(evaluator, results)

Expand Down
6 changes: 6 additions & 0 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3969,6 +3969,12 @@ def summarize(
if base_experiment:
comparison_experiment_id = base_experiment.id
comparison_experiment_name = base_experiment.name
else:
try:
comparison_experiment = state.api_conn().get_json(f"v1/experiment/{comparison_experiment_id}")
comparison_experiment_name = comparison_experiment.get("name")
except Exception:
pass

try:
summary_items = state.api_conn().get_json(
Expand Down
58 changes: 57 additions & 1 deletion py/src/braintrust/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib.util
import re
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from braintrust.logger import BraintrustState
Expand Down Expand Up @@ -78,6 +78,62 @@ def exact_match(input_value, output, expected):
assert result.summary.scores["exact_match"].score == 1.0


@pytest.mark.asyncio
async def test_run_evaluator_forwards_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
    """run_evaluator should pass the evaluator's base_experiment_id to summarize().

    The experiment's ``summarize`` is swapped for a mock so the test can check
    both the keyword arguments it receives and that whatever it returns ends
    up, unchanged, as ``result.summary``.
    """

    def exact_match(input_value, output, expected):
        # Binary scorer: full credit only when the output equals the expectation.
        if output == expected:
            return 1.0
        return 0.0

    evaluator = Evaluator(
        project_name="test-project",
        eval_name="test-evaluator",
        data=[EvalCase(input=1, expected=1)],
        task=lambda input_value: input_value,
        scores=[exact_match],
        experiment_name=None,
        metadata=None,
        base_experiment_id="base-exp-id",
    )

    experiment = init_test_exp("test-evaluator", "test-project")
    sentinel_summary = MagicMock()
    summarize_mock = MagicMock(return_value=sentinel_summary)
    experiment.summarize = summarize_mock

    outcome = await run_evaluator(
        experiment=experiment,
        evaluator=evaluator,
        position=None,
        filters=[],
    )

    # The mocked summary object must be returned as-is ...
    assert outcome.summary is sentinel_summary
    # ... and summarize() must have been asked to compare against the base experiment.
    summarize_mock.assert_called_once_with(
        summarize_scores=True,
        comparison_experiment_id="base-exp-id",
    )


def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
    """summarize() should resolve the comparison experiment's name from its id.

    The API connection is replaced with a mock whose ``get_json`` answers only
    the two expected endpoints; any other path fails the test immediately.
    """
    experiment = init_test_exp("test-evaluator", "test-project")
    fake_conn = MagicMock()

    def fake_get_json(path, args=None):
        # Canned responses for the experiment lookup and the comparison query.
        if path == "experiment-comparison2":
            return {"scores": {}, "metrics": {}}
        if path == "v1/experiment/base-exp-id":
            return {"name": "base-exp"}
        raise AssertionError(f"Unexpected get_json call: {path}, {args}")

    fake_conn.get_json.side_effect = fake_get_json

    with patch.object(experiment.state, "api_conn", return_value=fake_conn):
        summary = experiment.summarize(comparison_experiment_id="base-exp-id")

    assert summary.comparison_experiment_name == "base-exp"

    # Both the name lookup and the comparison query must have been issued.
    fake_conn.get_json.assert_any_call("v1/experiment/base-exp-id")
    expected_query = {
        "experiment_id": "test-evaluator",
        "base_experiment_id": "base-exp-id",
    }
    fake_conn.get_json.assert_any_call("experiment-comparison2", args=expected_query)


@pytest.mark.asyncio
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
async def test_run_evaluator_exposes_validated_parameter_values_to_hooks():
Expand Down
Loading