diff --git a/ot/sliced.py b/ot/sliced.py index 4a0c8417b..8f724ce73 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -20,6 +20,72 @@ ) +def _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx): + """Normalize input distributions before computing sliced Wasserstein distance. + + Parameters + ---------- + X_s : array-like, shape (n_s, d) + Source samples + X_t : array-like, shape (n_t, d) + Target samples + normalize : str or None + Normalization method. One of {None, 'standard', 'minmax', 'l2'}. + normalize_mode : str + Reference for computing statistics. One of {'joint', 'source', 'target'}. + Ignored when normalize is None or 'l2'. + nx : backend + POT backend instance (from ot.backend.get_backend) + + Returns + ------- + X_s_out : array-like, shape (n_s, d) + Normalized source samples + X_t_out : array-like, shape (n_t, d) + Normalized target samples + """ + if normalize is None: + return X_s, X_t + + if normalize_mode not in ("joint", "source", "target"): + raise ValueError( + f"Invalid normalize_mode '{normalize_mode}'. " + "Expected one of: 'joint', 'source', 'target'." + ) + + if normalize == "standard": + # TODO: full implementation + # - compute mean/std using nx ops based on normalize_mode + # - apply to both X_s and X_t + # - handle zero-variance columns with warnings.warn + raise NotImplementedError( + "normalize='standard' will be implemented in a follow-up commit." + ) + + elif normalize == "minmax": + # TODO: full implementation + # - compute min/max using nx ops based on normalize_mode + # - apply to both X_s and X_t + # - handle zero-range columns with warnings.warn + raise NotImplementedError( + "normalize='minmax' will be implemented in a follow-up commit." + ) + + elif normalize == "l2": + # TODO: full implementation + # - row-wise L2 normalization (normalize_mode is ignored) + # - handle zero-norm rows with warnings.warn + raise NotImplementedError( + "normalize='l2' will be implemented in a follow-up commit." + ) + + else: + raise ValueError( + f"Invalid normalize value '{normalize}'. " + "Expected one of: None, 'standard', 'minmax', 'l2'." + ) + + def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})` @@ -76,6 +142,8 @@ def sliced_wasserstein_distance( projections=None, seed=None, log=False, + normalize=None, + normalize_mode="joint", ): r""" Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance @@ -109,6 +177,24 @@ def sliced_wasserstein_distance( Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + normalize : str or None, optional + Normalization applied to X_s and X_t before computing the distance. + Useful when features have different scales. Options: + + - ``None`` : no normalization (default, preserves existing behavior) + - ``'standard'`` : zero mean, unit variance per feature dimension + - ``'minmax'`` : scale each feature to [0, 1] + - ``'l2'`` : normalize each sample to unit L2 norm (row-wise) + + normalize_mode : str, optional + Determines which samples are used to compute normalization statistics. + Ignored when ``normalize`` is ``None`` or ``'l2'``. Options: + + - ``'joint'`` : statistics from ``concat(X_s, X_t)`` (default). + Preserves symmetry: SWD(X_s, X_t) == SWD(X_t, X_s). + - ``'source'`` : statistics from ``X_s`` only. Useful for drift + detection where X_s is the reference distribution. + - ``'target'`` : statistics from ``X_t`` only. Returns ------- @@ -136,6 +222,8 @@ def sliced_wasserstein_distance( nx = get_backend(X_s, X_t, a, b, projections) + X_s, X_t = _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx) + n = X_s.shape[0] m = X_t.shape[0] @@ -181,6 +269,8 @@ def max_sliced_wasserstein_distance( projections=None, seed=None, log=False, + normalize=None, + normalize_mode="joint", ): r""" Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance @@ -215,6 +305,24 @@ def max_sliced_wasserstein_distance( Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + normalize : str or None, optional + Normalization applied to X_s and X_t before computing the distance. + Useful when features have different scales. Options: + + - ``None`` : no normalization (default, preserves existing behavior) + - ``'standard'`` : zero mean, unit variance per feature dimension + - ``'minmax'`` : scale each feature to [0, 1] + - ``'l2'`` : normalize each sample to unit L2 norm (row-wise) + + normalize_mode : str, optional + Determines which samples are used to compute normalization statistics. + Ignored when ``normalize`` is ``None`` or ``'l2'``. Options: + + - ``'joint'`` : statistics from ``concat(X_s, X_t)`` (default). + Preserves symmetry: SWD(X_s, X_t) == SWD(X_t, X_s). + - ``'source'`` : statistics from ``X_s`` only. Useful for drift + detection where X_s is the reference distribution. + - ``'target'`` : statistics from ``X_t`` only. Returns ------- @@ -242,6 +350,8 @@ def max_sliced_wasserstein_distance( nx = get_backend(X_s, X_t, a, b, projections) + X_s, X_t = _normalize_inputs(X_s, X_t, normalize, normalize_mode, nx) + n = X_s.shape[0] m = X_t.shape[0]