diff --git a/pyproject.toml b/pyproject.toml index fed3d37359e4..e101e225799a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -388,7 +388,6 @@ ignore_missing_imports = true # - python3 -m tools.mypy_helpers.find_easiest_modules [[tool.mypy.overrides]] module = [ - "sentry.snuba.metrics.query_builder", "sentry.testutils.cases", ] disable_error_code = [ diff --git a/src/sentry/snuba/metrics/query_builder.py b/src/sentry/snuba/metrics/query_builder.py index 066d553cc6f3..03d6ea28828d 100644 --- a/src/sentry/snuba/metrics/query_builder.py +++ b/src/sentry/snuba/metrics/query_builder.py @@ -3,7 +3,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime, timedelta from enum import Enum -from typing import Any, TypedDict, overload +from typing import Any, TypedDict, cast, overload import sentry_sdk from snuba_sdk import ( @@ -66,6 +66,7 @@ DATASET_COLUMNS, FIELD_ALIAS_MAPPINGS, FILTERABLE_TAGS, + MetricOperationType, NON_RESOLVABLE_TAG_VALUES, TS_COL_GROUP, DerivedMetricParseException, @@ -95,6 +96,16 @@ QUERY_PROJECT_LIMIT = 10 +MetricExpressionParams = Mapping[str, str | int | float] +MetricFieldKey = tuple[MetricOperationType | None, str, str] + + +def _as_metric_expression_params( + params: dict[str, None | str | int | float | Sequence[tuple[str | int, ...]]] | None, +) -> MetricExpressionParams | None: + return cast(MetricExpressionParams | None, params) + + def _strip_project_id(condition: Condition) -> Condition | None: if isinstance(condition, BooleanCondition): new_boolean_condition = BooleanCondition() @@ -128,7 +139,7 @@ def parse_public_field(field: str) -> MetricField: matches = PUBLIC_EXPRESSION_REGEX.match(field) if matches is not None: - operation = matches[1] + operation = cast(MetricOperationType, matches[1]) metric_name = matches[2] else: operation = None @@ -747,11 +758,14 @@ def translate_meta_results( continue elif alias_type == AliasMetaType.GROUP_BY_METRIC_FIELD: metric_groupby_field = 
alias_to_metric_group_by_field[record["name"]] + if not isinstance(metric_groupby_field.field, MetricField): + raise InvalidParams(f"Field {record['name']} was not a metric group-by field") defined_parent_meta_type = get_metric_object_from_metric_field( metric_groupby_field.field ).get_meta_type() - record["type"] = defined_parent_meta_type + if defined_parent_meta_type is not None: + record["type"] = defined_parent_meta_type elif alias_type == AliasMetaType.TAG: record["type"] = "string" elif alias_type == AliasMetaType.DATASET_COLUMN or alias_type == AliasMetaType.TIME_COLUMN: @@ -824,16 +838,27 @@ def generate_snql_for_action_by_fields( the snql generation starts to diverge significantly. """ - is_group_by = isinstance(metric_action_by_field, MetricGroupByField) - is_order_by = isinstance(metric_action_by_field, MetricOrderByField) + group_by_field = ( + metric_action_by_field + if isinstance(metric_action_by_field, MetricGroupByField) + else None + ) + order_by_field = ( + metric_action_by_field + if isinstance(metric_action_by_field, MetricOrderByField) + else None + ) + is_group_by = group_by_field is not None + is_order_by = order_by_field is not None if not is_group_by and not is_order_by: raise InvalidParams("The metric action must either be an order by or group by.") if isinstance(metric_action_by_field.field, str): # This transformation is currently supported only for group by because OrderBy doesn't support the Function type. 
if is_group_by and metric_action_by_field.field == "transaction": + assert group_by_field is not None return transform_null_transaction_to_unparameterized( - use_case_id, org_id, metric_action_by_field.alias + use_case_id, org_id, group_by_field.alias ) # Handles the case when we are trying to group or order by `project` for example, but we want @@ -856,7 +881,7 @@ def generate_snql_for_action_by_fields( exp = ( AliasedExpression( exp=Column(name=column_name), - alias=metric_action_by_field.alias, + alias=group_by_field.alias if group_by_field is not None else "", ) if is_group_by and not is_column else Column(name=column_name) @@ -865,7 +890,8 @@ def generate_snql_for_action_by_fields( if is_order_by: # We return a list in order to use the "extend" method and reduce the number of changes across # the codebase. - exp = [OrderBy(exp=exp, direction=metric_action_by_field.direction)] + assert order_by_field is not None + exp = [OrderBy(exp=exp, direction=order_by_field.direction)] return exp elif isinstance(metric_action_by_field.field, MetricField): @@ -878,16 +904,17 @@ def generate_snql_for_action_by_fields( return metric_expression.generate_groupby_statements( use_case_id=use_case_id, alias=metric_action_by_field.field.alias, - params=metric_action_by_field.field.params, + params=_as_metric_expression_params(metric_action_by_field.field.params), projects=projects, )[0] elif is_order_by: + assert order_by_field is not None return metric_expression.generate_orderby_clause( use_case_id=use_case_id, alias=metric_action_by_field.field.alias, - params=metric_action_by_field.field.params, + params=_as_metric_expression_params(metric_action_by_field.field.params), projects=projects, - direction=metric_action_by_field.direction, + direction=order_by_field.direction, ) else: raise NotImplementedError( @@ -925,14 +952,18 @@ def _build_where(self) -> list[BooleanCondition | Condition]: Condition( lhs=metric_expression.generate_where_statements( use_case_id=self._use_case_id, 
- params=condition.lhs.params, + params=_as_metric_expression_params(condition.lhs.params), projects=self._projects, alias=condition.lhs.alias, )[0], op=condition.op, rhs=( resolve_tag_value(self._use_case_id, self._org_id, condition.rhs) - if require_rhs_condition_resolution(condition.lhs.op) + if ( + condition.lhs.op is not None + and require_rhs_condition_resolution(condition.lhs.op) + and isinstance(condition.rhs, str) + ) else condition.rhs ), ) @@ -1069,6 +1100,8 @@ def __build_totals_and_series_queries( series_limit = self._metrics_query.max_limit if self._use_case_id in [UseCaseID.TRANSACTIONS, UseCaseID.SPANS]: + if self._metrics_query.interval is None: + raise InvalidParams("An interval is required for discover metrics series queries") time_groupby_column = self.__generate_time_groupby_column_for_discover_queries( self._metrics_query.interval ) @@ -1100,10 +1133,10 @@ def __generate_time_groupby_column_for_discover_queries(interval: int) -> Functi def __update_query_dicts_with_component_entities( self, component_entities: dict[MetricEntity, Sequence[str]], - metric_mri_to_obj_dict: dict[tuple[str | None, str, str], MetricExpressionBase], - fields_in_entities: dict[MetricEntity, list[tuple[str | None, str, str]]], + metric_mri_to_obj_dict: dict[MetricFieldKey, MetricExpressionBase], + fields_in_entities: dict[MetricEntity, list[MetricFieldKey]], parent_alias, - ) -> dict[tuple[str | None, str, str], MetricExpressionBase]: + ) -> dict[MetricFieldKey, MetricExpressionBase]: # At this point in time, we are only supporting raw metrics in the metrics attribute of # any instance of DerivedMetric, and so in this case the op will always be None # ToDo(ahmed): In future PR, we might want to allow for dependency metrics to also have an @@ -1128,8 +1161,8 @@ def __update_query_dicts_with_component_entities( return metric_mri_to_obj_dict def get_snuba_queries(self): - metric_mri_to_obj_dict: dict[tuple[str | None, str, str], MetricExpressionBase] = {} - 
fields_in_entities: dict[MetricEntity, list[tuple[str | None, str, str]]] = {} + metric_mri_to_obj_dict: dict[MetricFieldKey, MetricExpressionBase] = {} + fields_in_entities: dict[MetricEntity, list[MetricFieldKey]] = {} for select_field in self._metrics_query.select: metric_field_obj = metric_object_factory(select_field.op, select_field.metric_mri) @@ -1148,12 +1181,17 @@ def get_snuba_queries(self): projects=self._projects, use_case_id=self._use_case_id ) if isinstance(component_entities, dict): + cleaned_component_entities: dict[MetricEntity, Sequence[str]] = {} + for entity, metric_mris in component_entities.items(): + if entity is None: + raise DerivedMetricParseException("Entity parsed is in incorrect format") + cleaned_component_entities[entity] = metric_mris # In this case, component_entities is a dictionary with entity keys and # lists of metric_mris as values representing all the entities and # metric_mris combination that this metric_object is composed of, or rather # the instances of SingleEntityDerivedMetric that it is composed of metric_mri_to_obj_dict = self.__update_query_dicts_with_component_entities( - component_entities=component_entities, + component_entities=cleaned_component_entities, metric_mri_to_obj_dict=metric_mri_to_obj_dict, fields_in_entities=fields_in_entities, parent_alias=select_field.alias, @@ -1196,12 +1234,14 @@ def get_snuba_queries(self): # In order to support on demand metrics which require an interval (e.g. 
epm), # we want to pass the interval down via params so we can pass it to the associated snql_factory - params = {"interval": self._metrics_query.interval, **(params or {})} + params = {**(params or {})} + if self._metrics_query.interval is not None: + params.setdefault("interval", self._metrics_query.interval) select += metric_field_obj.generate_select_statements( projects=self._projects, use_case_id=self._use_case_id, alias=field[2], - params=params, + params=_as_metric_expression_params(params), ) metric_ids_set |= metric_field_obj.generate_metric_ids( self._projects, self._use_case_id @@ -1255,7 +1295,7 @@ def __init__( self, organization_id: int, metrics_query: DeprecatingMetricsQuery, - fields_in_entities: dict[MetricEntity, list[tuple[str | None, str, str]]], + fields_in_entities: dict[MetricEntity, list[MetricFieldKey]], intervals: list[datetime], results, use_case_id: UseCaseID, @@ -1436,7 +1476,7 @@ def resolve_tag_value(value: int | str | None) -> str | None: except KeyError: params = None totals[alias] = metric_obj.run_post_query_function( - totals, params=params, alias=alias + totals, params=_as_metric_expression_params(params), alias=alias ) if series is not None: @@ -1451,7 +1491,7 @@ def resolve_tag_value(value: int | str | None) -> str | None: except KeyError: params = None series[alias][idx] = metric_obj.run_post_query_function( - series, params=params, idx=idx, alias=alias + series, params=_as_metric_expression_params(params), idx=idx, alias=alias ) # Remove the extra fields added due to the constituent metrics that were added