diff --git a/py/src/braintrust/integrations/langchain/callbacks.py b/py/src/braintrust/integrations/langchain/callbacks.py index 96646604..2cb9b37d 100644 --- a/py/src/braintrust/integrations/langchain/callbacks.py +++ b/py/src/braintrust/integrations/langchain/callbacks.py @@ -617,24 +617,6 @@ def _get_model_name_from_response(response: LLMResult) -> str | None: return model_name -def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool: - # LangChain provider packages use different cache-token conventions: - # - OpenAI-style responses report cache reads as a subset of input_tokens. - # - Anthropic-style responses report cache reads/creation separately from input_tokens. - # - # Avoid provider-name checks here so any LangChain integration using the same - # "separate cache tokens" schema gets normalized, while providers that only - # expose cache_read as input-token detail do not get double-counted. - return any( - key in input_token_details - for key in ( - "cache_creation", - "ephemeral_5m_input_tokens", - "ephemeral_1h_input_tokens", - ) - ) - - def _get_metrics_from_response(response: LLMResult): metrics = {} @@ -685,15 +667,16 @@ def _get_metrics_from_response(response: LLMResult): completion_tokens = metrics.get("completion_tokens") total_tokens = metrics.get("total_tokens") if prompt_tokens is not None and completion_tokens is not None: - if ( - cache_tokens - and total_tokens == prompt_tokens + completion_tokens - and _cache_tokens_are_separate_from_input_tokens(input_token_details) - ): + # LangChain's UsageMetadata contract makes input_token_details a + # breakdown of input_tokens, so cache tokens already count toward + # the prompt total (langchain-anthropic >= 0.2.3, langchain-aws, + # langchain-openai all comply). Cache tokens exceeding the prompt + # total means the integration reported uncached input only — fold + # cache tokens back in so prompt/total stay internally consistent. + if cache_tokens > prompt_tokens and total_tokens == prompt_tokens + completion_tokens: prompt_tokens += cache_tokens metrics["prompt_tokens"] = prompt_tokens - if total_tokens is not None: - metrics["total_tokens"] = total_tokens + cache_tokens + metrics["total_tokens"] = total_tokens + cache_tokens metrics["tokens"] = prompt_tokens + completion_tokens if not metrics or not any(metrics.values()): diff --git a/py/src/braintrust/integrations/langchain/test_callbacks.py b/py/src/braintrust/integrations/langchain/test_callbacks.py index e05a5775..44d51338 100644 --- a/py/src/braintrust/integrations/langchain/test_callbacks.py +++ b/py/src/braintrust/integrations/langchain/test_callbacks.py @@ -1098,6 +1098,12 @@ def test_prompt_caching_tokens(logger_memory_logger): assert first_metrics["prompt_tokens"] >= first_cache_creation_tokens assert first_metrics["tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"] + # langchain-anthropic already folds cache read/creation tokens into + # usage_metadata input_tokens; the callback must not add them again. + assert res.usage_metadata is not None + assert first_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"] + assert first_metrics["total_tokens"] == res.usage_metadata["total_tokens"] + second_metrics = None for attempt in range(3): res = model.invoke( @@ -1134,6 +1140,10 @@ def test_prompt_caching_tokens(logger_memory_logger): assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"] assert second_metrics["tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"] + assert res.usage_metadata is not None + assert second_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"] + assert second_metrics["total_tokens"] == res.usage_metadata["total_tokens"] + @pytest.mark.vcr def test_image_input(logger_memory_logger):