Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ ignore_missing_imports = true
# - python3 -m tools.mypy_helpers.find_easiest_modules
[[tool.mypy.overrides]]
module = [
"sentry.snuba.metrics.query_builder",
"sentry.testutils.cases",
]
disable_error_code = [
Expand Down
88 changes: 64 additions & 24 deletions src/sentry/snuba/metrics/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, TypedDict, overload
from typing import Any, TypedDict, cast, overload

import sentry_sdk
from snuba_sdk import (
Expand Down Expand Up @@ -66,6 +66,7 @@
DATASET_COLUMNS,
FIELD_ALIAS_MAPPINGS,
FILTERABLE_TAGS,
MetricOperationType,
NON_RESOLVABLE_TAG_VALUES,
TS_COL_GROUP,
DerivedMetricParseException,
Expand Down Expand Up @@ -95,6 +96,16 @@
QUERY_PROJECT_LIMIT = 10


MetricExpressionParams = Mapping[str, str | int | float]
MetricFieldKey = tuple[MetricOperationType | None, str, str]


def _as_metric_expression_params(
params: dict[str, None | str | int | float | Sequence[tuple[str | int, ...]]] | None,
) -> MetricExpressionParams | None:
return cast(MetricExpressionParams | None, params)


def _strip_project_id(condition: Condition) -> Condition | None:
if isinstance(condition, BooleanCondition):
new_boolean_condition = BooleanCondition()
Expand Down Expand Up @@ -128,7 +139,7 @@ def parse_public_field(field: str) -> MetricField:
matches = PUBLIC_EXPRESSION_REGEX.match(field)

if matches is not None:
operation = matches[1]
operation = cast(MetricOperationType, matches[1])
metric_name = matches[2]
else:
operation = None
Expand Down Expand Up @@ -747,11 +758,14 @@ def translate_meta_results(
continue
elif alias_type == AliasMetaType.GROUP_BY_METRIC_FIELD:
metric_groupby_field = alias_to_metric_group_by_field[record["name"]]
if not isinstance(metric_groupby_field.field, MetricField):
raise InvalidParams(f"Field {record['name']} was not a metric group-by field")
defined_parent_meta_type = get_metric_object_from_metric_field(
metric_groupby_field.field
).get_meta_type()

record["type"] = defined_parent_meta_type
if defined_parent_meta_type is not None:
record["type"] = defined_parent_meta_type
elif alias_type == AliasMetaType.TAG:
record["type"] = "string"
elif alias_type == AliasMetaType.DATASET_COLUMN or alias_type == AliasMetaType.TIME_COLUMN:
Expand Down Expand Up @@ -824,16 +838,27 @@ def generate_snql_for_action_by_fields(
the snql generation starts to diverge significantly.
"""

is_group_by = isinstance(metric_action_by_field, MetricGroupByField)
is_order_by = isinstance(metric_action_by_field, MetricOrderByField)
group_by_field = (
metric_action_by_field
if isinstance(metric_action_by_field, MetricGroupByField)
else None
)
order_by_field = (
metric_action_by_field
if isinstance(metric_action_by_field, MetricOrderByField)
else None
)
is_group_by = group_by_field is not None
is_order_by = order_by_field is not None
if not is_group_by and not is_order_by:
raise InvalidParams("The metric action must either be an order by or group by.")

if isinstance(metric_action_by_field.field, str):
# This transformation is currently supported only for group by because OrderBy doesn't support the Function type.
if is_group_by and metric_action_by_field.field == "transaction":
assert group_by_field is not None
return transform_null_transaction_to_unparameterized(
use_case_id, org_id, metric_action_by_field.alias
use_case_id, org_id, group_by_field.alias
)

# Handles the case when we are trying to group or order by `project` for example, but we want
Expand All @@ -856,7 +881,7 @@ def generate_snql_for_action_by_fields(
exp = (
AliasedExpression(
exp=Column(name=column_name),
alias=metric_action_by_field.alias,
alias=group_by_field.alias if group_by_field is not None else "",
)
if is_group_by and not is_column
else Column(name=column_name)
Expand All @@ -865,7 +890,8 @@ def generate_snql_for_action_by_fields(
if is_order_by:
# We return a list in order to use the "extend" method and reduce the number of changes across
# the codebase.
exp = [OrderBy(exp=exp, direction=metric_action_by_field.direction)]
assert order_by_field is not None
exp = [OrderBy(exp=exp, direction=order_by_field.direction)]

return exp
elif isinstance(metric_action_by_field.field, MetricField):
Expand All @@ -878,16 +904,17 @@ def generate_snql_for_action_by_fields(
return metric_expression.generate_groupby_statements(
use_case_id=use_case_id,
alias=metric_action_by_field.field.alias,
params=metric_action_by_field.field.params,
params=_as_metric_expression_params(metric_action_by_field.field.params),
projects=projects,
)[0]
elif is_order_by:
assert order_by_field is not None
return metric_expression.generate_orderby_clause(
use_case_id=use_case_id,
alias=metric_action_by_field.field.alias,
params=metric_action_by_field.field.params,
params=_as_metric_expression_params(metric_action_by_field.field.params),
projects=projects,
direction=metric_action_by_field.direction,
direction=order_by_field.direction,
)
else:
raise NotImplementedError(
Expand Down Expand Up @@ -925,14 +952,18 @@ def _build_where(self) -> list[BooleanCondition | Condition]:
Condition(
lhs=metric_expression.generate_where_statements(
use_case_id=self._use_case_id,
params=condition.lhs.params,
params=_as_metric_expression_params(condition.lhs.params),
projects=self._projects,
alias=condition.lhs.alias,
)[0],
op=condition.op,
rhs=(
resolve_tag_value(self._use_case_id, self._org_id, condition.rhs)
if require_rhs_condition_resolution(condition.lhs.op)
if (
condition.lhs.op is not None
and require_rhs_condition_resolution(condition.lhs.op)
and isinstance(condition.rhs, str)
)
else condition.rhs
),
)
Expand Down Expand Up @@ -1069,6 +1100,8 @@ def __build_totals_and_series_queries(
series_limit = self._metrics_query.max_limit

if self._use_case_id in [UseCaseID.TRANSACTIONS, UseCaseID.SPANS]:
if self._metrics_query.interval is None:
raise InvalidParams("An interval is required for discover metrics series queries")
time_groupby_column = self.__generate_time_groupby_column_for_discover_queries(
self._metrics_query.interval
)
Expand Down Expand Up @@ -1100,10 +1133,10 @@ def __generate_time_groupby_column_for_discover_queries(interval: int) -> Functi
def __update_query_dicts_with_component_entities(
self,
component_entities: dict[MetricEntity, Sequence[str]],
metric_mri_to_obj_dict: dict[tuple[str | None, str, str], MetricExpressionBase],
fields_in_entities: dict[MetricEntity, list[tuple[str | None, str, str]]],
metric_mri_to_obj_dict: dict[MetricFieldKey, MetricExpressionBase],
fields_in_entities: dict[MetricEntity, list[MetricFieldKey]],
parent_alias,
) -> dict[tuple[str | None, str, str], MetricExpressionBase]:
) -> dict[MetricFieldKey, MetricExpressionBase]:
# At this point in time, we are only supporting raw metrics in the metrics attribute of
# any instance of DerivedMetric, and so in this case the op will always be None
# ToDo(ahmed): In future PR, we might want to allow for dependency metrics to also have an
Expand All @@ -1128,8 +1161,8 @@ def __update_query_dicts_with_component_entities(
return metric_mri_to_obj_dict

def get_snuba_queries(self):
metric_mri_to_obj_dict: dict[tuple[str | None, str, str], MetricExpressionBase] = {}
fields_in_entities: dict[MetricEntity, list[tuple[str | None, str, str]]] = {}
metric_mri_to_obj_dict: dict[MetricFieldKey, MetricExpressionBase] = {}
fields_in_entities: dict[MetricEntity, list[MetricFieldKey]] = {}

for select_field in self._metrics_query.select:
metric_field_obj = metric_object_factory(select_field.op, select_field.metric_mri)
Expand All @@ -1148,12 +1181,17 @@ def get_snuba_queries(self):
projects=self._projects, use_case_id=self._use_case_id
)
if isinstance(component_entities, dict):
cleaned_component_entities: dict[MetricEntity, Sequence[str]] = {}
for entity, metric_mris in component_entities.items():
if entity is None:
raise DerivedMetricParseException("Entity parsed is in incorrect format")
cleaned_component_entities[entity] = metric_mris
# In this case, component_entities is a dictionary with entity keys and
# lists of metric_mris as values representing all the entities and
# metric_mris combination that this metric_object is composed of, or rather
# the instances of SingleEntityDerivedMetric that it is composed of
metric_mri_to_obj_dict = self.__update_query_dicts_with_component_entities(
component_entities=component_entities,
component_entities=cleaned_component_entities,
metric_mri_to_obj_dict=metric_mri_to_obj_dict,
fields_in_entities=fields_in_entities,
parent_alias=select_field.alias,
Expand Down Expand Up @@ -1196,12 +1234,14 @@ def get_snuba_queries(self):

# In order to support on demand metrics which require an interval (e.g. epm),
# we want to pass the interval down via params so we can pass it to the associated snql_factory
params = {"interval": self._metrics_query.interval, **(params or {})}
params = {**(params or {})}
if self._metrics_query.interval is not None:
params["interval"] = self._metrics_query.interval
select += metric_field_obj.generate_select_statements(
projects=self._projects,
use_case_id=self._use_case_id,
alias=field[2],
params=params,
params=_as_metric_expression_params(params),
)
metric_ids_set |= metric_field_obj.generate_metric_ids(
self._projects, self._use_case_id
Expand Down Expand Up @@ -1255,7 +1295,7 @@ def __init__(
self,
organization_id: int,
metrics_query: DeprecatingMetricsQuery,
fields_in_entities: dict[MetricEntity, list[tuple[str | None, str, str]]],
fields_in_entities: dict[MetricEntity, list[MetricFieldKey]],
intervals: list[datetime],
results,
use_case_id: UseCaseID,
Expand Down Expand Up @@ -1436,7 +1476,7 @@ def resolve_tag_value(value: int | str | None) -> str | None:
except KeyError:
params = None
totals[alias] = metric_obj.run_post_query_function(
totals, params=params, alias=alias
totals, params=_as_metric_expression_params(params), alias=alias
)

if series is not None:
Expand All @@ -1451,7 +1491,7 @@ def resolve_tag_value(value: int | str | None) -> str | None:
except KeyError:
params = None
series[alias][idx] = metric_obj.run_post_query_function(
series, params=params, idx=idx, alias=alias
series, params=_as_metric_expression_params(params), idx=idx, alias=alias
)

# Remove the extra fields added due to the constituent metrics that were added
Expand Down
Loading