# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, Literal, Tuple
import numpy as np
import numpy.typing as npt
import pandas as pd
from balance import util as balance_util
from balance.weighting_methods import (
adjust_null as balance_adjust_null,
cbps as balance_cbps,
ipw as balance_ipw,
poststratify as balance_poststratify,
rake as balance_rake,
)
from pandas.api.types import is_bool_dtype, is_numeric_dtype
# Package-level logger: ``__package__`` names the logger after the enclosing
# package rather than after this individual module.
logger: logging.Logger = logging.getLogger(__package__)
# Registry mapping each supported adjustment-method name to the callable that
# implements it. ``_find_adjustment_method`` uses this dict as its default
# lookup table, so adding an entry here makes a new method name resolvable.
BALANCE_WEIGHTING_METHODS: Dict[str, Callable[..., Any]] = {
"ipw": balance_ipw.ipw,
"cbps": balance_cbps.cbps,
"null": balance_adjust_null.adjust_null,
"poststratify": balance_poststratify.poststratify,
"rake": balance_rake.rake,
}
def _validate_limit(limit: float | int | None, n_weights: int) -> float | None:
    """Check a winsorization percentile limit and nudge it slightly upward.

    Finite, non-zero limits are validated to lie in ``[0, 1]`` and then
    increased by ``min(2 / n_weights, limit / 10)`` (the result is capped at
    ``1.0``). The nudge makes sure that at least one value is winsorized at
    the boundary percentile, even with discrete data or floating-point
    round-off.

    Pass-through cases (no validation, no adjustment):
        - ``None`` is returned as ``None`` (no winsorization on this side).
        - ``0`` is returned as ``0.0`` (no winsorization on this side).
        - Non-finite values (e.g. ``inf``) are returned as floats, unchanged.

    Args:
        limit (Union[float, int, None]): The percentile limit to validate.
            Finite values are expected to be between 0 and 1.
        n_weights (int): Number of weights in the dataset; larger samples
            receive a smaller adjustment. Values below 1 are treated as 1.

    Returns:
        Union[float, None]: The adjusted limit, or ``None`` when ``limit``
            is ``None``.

    Raises:
        ValueError: If ``limit`` is finite and outside ``[0, 1]``.
    """
    if limit is None:
        return None

    value = float(limit)
    if value == 0:
        return 0.0

    # Non-finite inputs skip both the range check and the adjustment.
    if not np.isfinite(value):
        return value

    if not (0 <= value <= 1):
        raise ValueError("Percentile limits must be between 0 and 1")

    # Guard the division against an empty / non-positive sample size.
    sample_size = max(n_weights, 1)
    bump = min(2.0 / sample_size, value / 10.0)
    return min(value + bump, 1.0)
def _quantile_with_method(
    data: pd.Series | npt.NDArray, q: float, method: str
) -> float:
    """Return a single quantile of ``data`` using the requested algorithm.

    Args:
        data: Array-like input; it is converted with ``np.asarray`` to a
            ``float64`` array before the quantile is taken.
        q: Quantile to compute in the inclusive range ``[0, 1]``.
        method: Name of the quantile algorithm, e.g. ``"higher"`` or
            ``"lower"`` (step functions) or ``"linear"`` (continuous
            interpolation). It is passed to ``np.quantile`` as ``method``;
            if that call raises ``TypeError`` it is passed as
            ``interpolation`` instead.

    Returns:
        ``float``: The quantile value as a built-in ``float``.

    Examples:
        .. code-block:: python

            _quantile_with_method([1, 2, 3, 4], 0.25, "higher")
            2.0
            _quantile_with_method([1, 2, 3, 4], 0.75, "lower")
            3.0
    """
    values = np.asarray(data, dtype=np.float64)
    try:
        quantile_value = float(np.quantile(values, q, method=method))
    except TypeError:
        # NumPy releases older than 1.22 only know the ``interpolation`` kwarg.
        quantile_value = float(np.quantile(values, q, interpolation=method))
    return quantile_value
def trim_weights(
    weights: pd.Series | npt.NDArray | list[float] | tuple[float, ...],
    weight_trimming_mean_ratio: float | int | None = None,
    weight_trimming_percentile: float | Tuple[float, float] | None = None,
    verbose: bool = False,
    keep_sum_of_weights: bool = True,
    target_sum_weights: float | int | np.floating | None = None,
) -> pd.Series:
    """Trim extreme weights using mean ratio clipping or percentile-based winsorization.

    The user cannot supply both weight_trimming_mean_ratio and
    weight_trimming_percentile. If neither is supplied, no trimming and no
    mean-preserving rescale take place: the weights are returned unchanged
    (as float64), unless ``target_sum_weights`` requests a rescale.

    **Mean Ratio Trimming (weight_trimming_mean_ratio)**:
        When specified, weights are clipped from above at
        mean(weights) * ratio, then (with ``keep_sum_of_weights=True``)
        renormalized to preserve the original mean.

        Note: Final weights may slightly exceed the trimming ratio due to
        renormalization redistributing the clipped weight mass across all
        observations.

    **Percentile-Based Winsorization (weight_trimming_percentile)**:
        When specified, extreme weights are replaced with less extreme values
        using percentile clipping. By default, winsorization affects both
        tails of the distribution symmetrically, unlike mean ratio trimming
        which only clips from above.

        Behavior:
            - Single value (e.g., 0.1): Winsorizes below 10th AND above 90th percentile
            - Tuple (lower, upper): Winsorizes independently on each side
                - (0.1, 0): Only winsorizes below 10th percentile
                - (0, 0.1): Only winsorizes above 90th percentile
                - (0.01, 0.05): Winsorizes below 1st AND above 95th percentile

        Important implementation detail: Percentile limits are automatically
        adjusted upward slightly (via _validate_limit) to ensure at least one
        value gets winsorized at boundary percentiles. The adjustment is
        min(2/n_weights, limit/10), capped at 1.0.

    After trimming/winsorization, if keep_sum_of_weights=True (default),
    weights are rescaled to preserve the original sum of weights.
    Alternatively, pass a ``target_sum_weights`` to rescale the trimmed
    weights so their sum matches a desired total.

    Args:
        weights (pd.Series | np.ndarray | list[float] | tuple[float, ...]):
            Weights to trim. Arrays and sequences are converted to pd.Series
            internally.
        weight_trimming_mean_ratio (float | int | None, optional): Ratio for
            upper bound clipping as mean(weights) * ratio. Mutually exclusive
            with weight_trimming_percentile. Defaults to None.
        weight_trimming_percentile (float | tuple[float, float] | None, optional):
            Percentile limits for winsorization. Value(s) must be between 0
            and 1.
            - Single float: Symmetric winsorization on both tails
            - tuple[float, float]: (lower_percentile, upper_percentile) for
              independent control of each tail
            Mutually exclusive with weight_trimming_mean_ratio. Defaults to
            None.
        verbose (bool, optional): Whether to log details about the trimming
            process. Defaults to False.
        keep_sum_of_weights (bool, optional): Whether to rescale weights
            after trimming to preserve the original sum of weights. Has no
            effect when no trimming argument is given. Defaults to True.
        target_sum_weights (float | int | np.floating | None, optional): If
            provided, rescale the trimmed weights so their sum equals this
            target. ``None`` (default) leaves the post-trimming sum unchanged.

    Raises:
        TypeError: If weights is not a pd.Series, np.ndarray, list or tuple.
        ValueError: If both weight_trimming_mean_ratio and
            weight_trimming_percentile are specified, if a
            weight_trimming_percentile sequence has length != 2, or if
            ``target_sum_weights`` is given while the weights sum to zero.

    Returns:
        pd.Series (of type float64): Trimmed weights with the same index and
            name as the input. Empty input yields an empty float64 series.

    Examples:
        .. code-block:: python

            import pandas as pd
            from balance.adjustment import trim_weights

            weights = pd.Series(range(1, 101))
            symmetric = trim_weights(
                weights,
                weight_trimming_percentile=0.01,
                keep_sum_of_weights=False,
            )
            symmetric.equals(pd.Series(range(1, 101)).clip(3, 98).astype(float))
            # True

            upper = trim_weights(
                weights,
                weight_trimming_percentile=(0.0, 0.05),
                keep_sum_of_weights=False,
            )
            upper.equals(pd.Series(range(1, 101)).clip(upper=94).astype(float))
            # True
    """
    original_name = getattr(weights, "name", None)

    # Normalise every accepted container to a float64 Series.
    if isinstance(weights, pd.Series):
        weights = weights.astype(np.float64)
    elif isinstance(weights, (np.ndarray, list, tuple)):
        weights = pd.Series(
            np.asarray(weights, dtype=np.float64), dtype=np.float64, name=original_name
        )
    else:
        raise TypeError(
            "weights must be np.array, list, tuple, or pd.Series, are of type: "
            f"{type(weights)}"
        )

    # Reject contradictory arguments up front so the error does not depend on
    # whether the input happens to be empty.
    if (weight_trimming_mean_ratio is not None) and (
        weight_trimming_percentile is not None
    ):
        raise ValueError(
            "Only one of weight_trimming_mean_ratio and "
            "weight_trimming_percentile can be set"
        )

    n_weights = len(weights)
    if n_weights == 0:
        # Nothing to trim; hand back the coerced series so name/index survive.
        return weights

    original_mean = float(np.mean(weights))
    trimming_requested = False

    if weight_trimming_mean_ratio is not None:
        trimming_requested = True
        max_val = weight_trimming_mean_ratio * original_mean
        if verbose:
            # Share of (non-missing) weights above the cap, measured before clipping.
            percent_trimmed = weights[weights > max_val].count() / weights.count()
            if percent_trimmed > 0:
                logger.debug("Clipping weights to %s (before renormalizing)", max_val)
                logger.debug("Clipped %s of the weights", percent_trimmed)
            else:
                logger.debug("No extreme weights were trimmed")
        weights = weights.clip(upper=max_val)
    elif weight_trimming_percentile is not None:
        trimming_requested = True
        # Winsorize using percentile clipping.
        percentile = weight_trimming_percentile
        if isinstance(percentile, (list, tuple, np.ndarray)):
            if len(percentile) != 2:
                raise ValueError(
                    "weight_trimming_percentile must be a single value or a length-2 iterable"
                )
            lower_limit, upper_limit = percentile
        else:
            lower_limit = upper_limit = percentile

        # Treat 0 or None as "no winsorization" for that side before applying
        # validation adjustments that guarantee at least one affected value.
        lower_limit_adjusted = _validate_limit(
            None if (lower_limit is None or lower_limit == 0) else float(lower_limit),
            n_weights,
        )
        upper_limit_adjusted = _validate_limit(
            None if (upper_limit is None or upper_limit == 0) else float(upper_limit),
            n_weights,
        )

        # "higher"/"lower" pick actual observed values as the clipping bounds.
        lower_bound = (
            None
            if lower_limit_adjusted is None
            else _quantile_with_method(weights, lower_limit_adjusted, "higher")
        )
        upper_bound = (
            None
            if upper_limit_adjusted is None
            else _quantile_with_method(weights, 1 - upper_limit_adjusted, "lower")
        )

        if verbose:
            logger.debug(
                "Winsorizing weights to %s percentile", str(weight_trimming_percentile)
            )

        weights = weights.clip(lower=lower_bound, upper=upper_bound)

    # Only rescale when something was actually requested: without trimming the
    # rescale is a mathematical no-op that would merely add rounding noise and
    # turn all-zero weights into NaN (0 / 0).
    if keep_sum_of_weights and trimming_requested:
        weights = weights / np.mean(weights) * original_mean

    if target_sum_weights is not None:
        target_total = float(target_sum_weights)
        current_total = float(weights.sum())
        if np.isclose(current_total, 0.0):
            raise ValueError("Cannot normalise weights because their sum is zero.")
        weights = weights * (target_total / current_total)

    weights = weights.rename(original_name)

    return weights
# Unique marker used as the default of ``transformations_effective`` in
# ``_reject_data_dependent_transfer``. It lets the function tell apart
# "argument omitted" from "argument explicitly set to None"; the latter is a
# legitimate value meaning the fit-time effective transformations dict was
# filtered down to nothing and stored as None (e.g. every user-supplied
# transform key was out of scope for the selected variables).
_UNSET_EFFECTIVE_TRANSFORMATIONS: Any = object()


def _reject_data_dependent_transfer(
    transformations_origin: Any,
    *,
    method_name: str,
    transformations_effective: Any = _UNSET_EFFECTIVE_TRANSFORMATIONS,
) -> None:
    """Block ``predict_weights(data=...)`` when fit-time transformations cannot be replayed.

    Used by ``rake._predict_weights_from_model`` and
    ``poststratify._predict_weights_from_model``. ``cbps`` and ``ipw`` have
    their own ``predict_weights(data=...)`` paths in :class:`BalanceFrame`
    but currently neither persist ``transformations_origin`` nor call this
    guard; they may adopt both if they need the same protection.

    Two situations are rejected because the transformations depend on the
    data they are applied to, so results computed at fit time do not carry
    over to a different scoring sample:

    - ``transformations='default'``: the default dispatcher derives bin
      edges / kept levels from the fit-time data.
    - A ``dict`` whose values directly reference balance's known
      data-dependent helpers, ``quantize`` or ``fct_lump``.

    The helper scan looks at ``transformations_effective`` (the filtered
    dict really applied at fit time, ``model['transformations']``) whenever
    the caller passes it, so a helper attached to an out-of-scope variable
    that was filtered away does not block scoring. ``None`` is accepted
    there and means "nothing applied", hence no rejection. The ``'default'``
    test always uses ``transformations_origin`` so that the user's literal
    request is what gets rejected.

    This is a best-effort check. It does not detect indirect usage such as
    ``functools.partial(fct_lump, prop=0.1)``, wrapper functions, or
    user-defined data-dependent transformations. In general, any callable
    whose output for one row depends on other rows is unsafe to replay on a
    different sample; users of such callables should either wrap them as
    deterministic functions of stored fit-time parameters (pre-computed bin
    edges or kept-level lists rather than the input frame) or re-fit the
    method on the scoring data.

    Args:
        transformations_origin: The ``transformations`` argument given at
            fit time, stored in ``model['transformations_origin']``. Drives
            the ``'default'`` rejection.
        method_name: Weighting-method name used in the error messages (e.g.
            ``'rake'``, ``'poststratify'``).
        transformations_effective: The effective ``transformations`` dict
            applied at fit time (typically ``model['transformations']``).
            Drives the data-dependent-helper rejection. ``None`` is a
            meaningful value (empty/filtered case, no rejection). When the
            argument is omitted, ``transformations_origin`` is inspected
            instead so direct unit tests can pass a single dict.

    Raises:
        ValueError: If ``transformations_origin`` is ``'default'`` or the
            inspected transformations dict references
            ``quantize``/``fct_lump`` directly.
    """
    label = method_name.capitalize()

    if transformations_origin == "default":
        raise ValueError(
            f"{label} predict_weights(data=...) is unsupported for "
            "models fitted with transformations='default' because those "
            "transformations are data-dependent and not replayable across "
            f"new samples. BalanceFrame.fit(method='{method_name}') uses transformations='default' "
            "out of the box; to enable transfer scoring, pass deterministic "
            "transformations explicitly at fit time (a custom function that "
            "closes over fit-time-computed parameters such as bin edges or "
            "kept-level lists, not over the input frame) or re-fit "
            f"{method_name} on the scoring data."
        )

    # Prefer the effective dict when supplied (even if it is None).
    if transformations_effective is _UNSET_EFFECTIVE_TRANSFORMATIONS:
        candidate = transformations_origin
    else:
        candidate = transformations_effective

    if not isinstance(candidate, dict):
        return

    from balance.utils.data_transformation import fct_lump, quantize

    unsafe_helpers = {quantize, fct_lump}
    offender_names = set()
    for transform in candidate.values():
        if transform in unsafe_helpers:
            offender_names.add(getattr(transform, "__name__", repr(transform)))

    if not offender_names:
        return

    offender_list = ", ".join(sorted(offender_names))
    raise ValueError(
        f"{label} predict_weights(data=...) is unsupported "
        "for models fitted with data-dependent transformations "
        f"({offender_list}). These recompute bins/levels "
        "from the scoring data, so stored cell ratios no longer "
        "line up with the transformed scoring cells. To enable "
        "transfer scoring, replace each data-dependent helper with "
        "a deterministic wrapper that closes over fit-time-computed "
        "parameters (e.g. pre-computed bin edges or kept-level "
        "lists), not over the input frame — or re-fit "
        f"{method_name} on the scoring data."
    )
def _find_adjustment_method(
    method: Literal["cbps", "ipw", "null", "poststratify", "rake"],
    WEIGHTING_METHODS: Dict[str, Callable[..., Any]] = BALANCE_WEIGHTING_METHODS,
) -> Callable[..., Any]:
    """Resolve an adjustment-method name to the callable that implements it.

    Args:
        method (Literal["cbps", "ipw", "null", "poststratify", "rake"]): Name
            of the adjustment method to look up.
        WEIGHTING_METHODS (Dict[str, Callable]): Lookup table whose keys are
            method names and whose values are the corresponding functions.
            Defaults to ``BALANCE_WEIGHTING_METHODS``.

    Returns:
        Callable: The function registered under ``method``.

    Raises:
        ValueError: If ``method`` is not a key of ``WEIGHTING_METHODS``.
    """
    if method not in WEIGHTING_METHODS:
        raise ValueError(f"Unknown adjustment method: '{method}'")
    return WEIGHTING_METHODS[method]