Source code for balance.adjustment

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2.

# pyre-unsafe

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from typing import Callable, Dict, List, Literal, Tuple, Union

import numpy as np
import numpy.typing as npt
import pandas as pd
import scipy

from balance import util as balance_util
from balance.weighting_methods import (
    adjust_null as balance_adjust_null,
    cbps as balance_cbps,
    ipw as balance_ipw,
    poststratify as balance_poststratify,
    rake as balance_rake,
)
from pandas.api.types import is_bool_dtype, is_numeric_dtype

logger: logging.Logger = logging.getLogger(__package__)


BALANCE_WEIGHTING_METHODS = {
    "ipw": balance_ipw.ipw,
    "cbps": balance_cbps.cbps,
    "null": balance_adjust_null.adjust_null,
    "poststratify": balance_poststratify.poststratify,
    "rake": balance_rake.rake,
}


[docs] def trim_weights( weights: Union[pd.Series, npt.NDArray], # TODO: add support to more types of input weights? (e.g. list? other?) weight_trimming_mean_ratio: Union[float, int, None] = None, weight_trimming_percentile: Union[float, None] = None, verbose: bool = False, keep_sum_of_weights: bool = True, ) -> pd.Series: """Trim extreme weights. The user cannot supply both weight_trimming_mean_ratio and weight_trimming_percentile. If none are supplied, the original weights are returned. If `weight_trimming_mean_ratio` is not None, the weights are trimmed from above by mean(weights) * ratio. The weights are then normalized to have the original mean. Note that trimmed weights aren't actually bounded by trimming.ratio because the reduced weight is redistributed to arrive at the original mean. If `weight_trimming_percentile` is not None, the weights are trimmed according to the percentiles of the distribution of the weights. Note that weight_trimming_percentile by default clips both sides of the distribution, unlike trimming that only trims the weights from above. For example, `weight_trimming_percentile=0.1` trims below the 10th percentile AND above the 90th. If you only want to trim the upper side, specify `weight_trimming_percentile = (0, 0.1)`. If you only want to trim the lower side, specify `weight_trimming_percentile = (0.1, 0)`. Args: weights (Union[pd.Series, np.ndarray]): pd.Series of weights to trim. np.ndarray will be turned into pd.Series) of weights. weight_trimming_mean_ratio (Union[float, int], optional): indicating the ratio from above according to which the weights are trimmed by mean(weights) * ratio. Defaults to None. weight_trimming_percentile (Union[float], optional): if `weight_trimming_percentile` is not None, then we apply winsorization using :func:`scipy.stats.mstats.winsorize`. Ranges between 0 and 1. If a single value is passed, indicates the percentiles on both sides of the weight distribution beyond which the weights will be winsorized. If two values are passed, the first value is the lower percentiles below which winsorizing will be applied, and the second is the 1. - upper percentile above which winsorizing will be applied. For example, `weight_trimming_percentile=(0.01, 0.05)` will trim the weights with values below the 1st percentile and above the 95th percentile of the weight distribution. See also: [https://en.wikipedia.org/wiki/Winsorizing]. Defaults to None. verbose (bool, optional): whether to add to logger printout of trimming process. Defaults to False. keep_sum_of_weights (bool, optional): Set if the sum of weights after trimming should be the same as the sum of weights before trimming. Defaults to True. Raises: TypeError: If weights is not np.array or pd.Series. ValueError: If both weight_trimming_mean_ratio and weight_trimming_percentile are set. Returns: pd.Series (of type float64): Trimmed weights Examples: :: import pandas as pd from balance.adjustment import trim_weights print(trim_weights(pd.Series(range(1, 101)), weight_trimming_mean_ratio = None)) # 0 1.0 # 1 2.0 # 2 3.0 # 3 4.0 # 4 5.0 # ... # 95 96.0 # 96 97.0 # 97 98.0 # 98 99.0 # 99 100.0 # Length: 100, dtype: float64 print(trim_weights(pd.Series(range(1, 101)), weight_trimming_mean_ratio = 1.5)) # 0 1.064559 # 1 2.129117 # 2 3.193676 # 3 4.258235 # 4 5.322793 # ... # 95 80.640316 # 96 80.640316 # 97 80.640316 # 98 80.640316 # 99 80.640316 # Length: 100, dtype: float64 print(pd.DataFrame(trim_weights(pd.Series(range(1, 101)), weight_trimming_percentile=.01))) # 0 2.0 # 1 2.0 # 2 3.0 # 3 4.0 # 4 5.0 # .. ... # 95 96.0 # 96 97.0 # 97 98.0 # 98 99.0 # 99 99.0 # [100 rows x 1 columns] print(pd.DataFrame(trim_weights(pd.Series(range(1, 101)), weight_trimming_percentile=(0., .05)))) # 0 1.002979 # 1 2.005958 # 2 3.008937 # 3 4.011917 # 4 5.014896 # .. ... # 95 95.283019 # 96 95.283019 # 97 95.283019 # 98 95.283019 # 99 95.283019 """ if isinstance(weights, pd.Series): pass elif isinstance(weights, np.ndarray): weights = pd.Series(weights) else: raise TypeError( f"weights must be np.array or pd.Series, are of type: {type(weights)}" ) if (weight_trimming_mean_ratio is not None) and ( weight_trimming_percentile is not None ): raise ValueError( "Only one of weight_trimming_mean_ratio and " "weight_trimming_percentile can be set" ) original_mean = np.mean(weights) if weight_trimming_mean_ratio is not None: max_val = weight_trimming_mean_ratio * original_mean percent_trimmed = weights[weights > max_val].count() / weights.count() weights = weights.clip(upper=max_val) if verbose: if percent_trimmed > 0: logger.debug("Clipping weights to %s (before renormalizing)" % max_val) logger.debug("Clipped %s of the weights" % percent_trimmed) else: logger.debug("No extreme weights were trimmed") elif weight_trimming_percentile is not None: # Winsorize weights = scipy.stats.mstats.winsorize( weights, limits=weight_trimming_percentile, inplace=False ) if verbose: logger.debug( "Winsorizing weights to %s percentile" % str(weight_trimming_percentile) ) if keep_sum_of_weights: weights = weights / np.mean(weights) * original_mean return weights
[docs] def default_transformations( dfs: Union[Tuple[pd.DataFrame, ...], List[pd.DataFrame]] ) -> Dict[str, Callable]: """ Apply default transformations to dfs, i.e. quantize to numeric columns and fct_lump to non-numeric and boolean Args: dfs (Union[Tuple[pd.DataFrame, ...], List[pd.DataFrame]]): A list or tuple of dataframes Returns: Dict[str, Callable]: Dict of transformations """ dtypes = {} for d in dfs: dtypes.update(d.dtypes.to_dict()) transformations = {} for k, v in dtypes.items(): # Notice that in pandas: pd.api.types.is_numeric_dtype(pd.Series([True, False])) == True # Hence, we need to explicitly check that not is_bool_dtype(v) # see: https://github.com/pandas-dev/pandas/issues/38378 if (is_numeric_dtype(v)) and (not is_bool_dtype(v)): transformations[k] = balance_util.quantize else: transformations[k] = balance_util.fct_lump return transformations
[docs] def apply_transformations( dfs: Tuple[pd.DataFrame, ...], transformations: Union[Dict[str, Callable], str, None], drop: bool = True, ) -> Tuple[pd.DataFrame, ...]: """Apply the transformations specified in transformations to all of the dfs - if a column specified in `transformations` does not exist in the dataframes, it is added - if a column is not specified in `transformations`, it is dropped, unless drop==False - the dfs are concatenated together before transformations are applied, so functions like `max` are relative to the column in all dfs - Cannot transform the same variable twice, or add a variable and then transform it (i.e. the definition of the added variable should include the transformation) - if you get a cryptic error about mismatched data types, make sure your transformations are not being treated as additions because of missing columns (use `_set_warnings("DEBUG")` to check) Args: dfs (Tuple[pd.DataFrame, ...]): The DataFrames on which to operate transformations (Union[Dict[str, Callable], str, None]): Mapping from column name to function to apply. Transformations of existing columns should be specified as functions of those columns (e.g. `lambda x: x*2`), whereas additions of new columns should be specified as functions of the DataFrame (e.g. `lambda x: x.column_a + x.column_b`). drop (bool, optional): Whether to drop columns which are not specified in `transformations`. Defaults to True. Raises: NotImplementedError: When passing an unknown "transformations" argument. Returns: Tuple[pd.DataFrame, ...]: tuple of pd.DataFrames Examples: :: from balance.adjustment import apply_transformations import pandas as pd import numpy as np apply_transformations( (pd.DataFrame({'d': [1, 2, 3], 'e': [4, 5, 6]}),), {'d': lambda x: x*2, 'f': lambda x: x.d+x.e} ) # ( f d # 0 5 2 # 1 7 4 # 2 9 6,) """ # TODO: change assert to raise assert isinstance(dfs, tuple), "'dfs' argument must be a tuple of DataFrames" assert all( isinstance(x, pd.DataFrame) for x in dfs ), "'dfs' must contain DataFrames" if transformations is None: return dfs elif isinstance(transformations, str): if transformations == "default": transformations = default_transformations(dfs) else: raise NotImplementedError(f"Unknown transformations {transformations}") ns = [0] + list(np.cumsum([x.shape[0] for x in dfs])) boundaries = [(ns[i], ns[i + 1]) for i in range(0, len(ns) - 1)] indices = [x.index for x in dfs] all_data = pd.concat(dfs).reset_index(drop=True) # This is to avoid issues with trnasformations that cannot # be done on object with duplicate indecies # additions is new columns to add to data. i.e.: column names that appear in transformations # but are not present in all_data. # pyre-fixme[16]: Optional type has no attribute `columns`. additions = {k: v for k, v in transformations.items() if k not in all_data.columns} transformations = { k: v for k, v in transformations.items() if k in all_data.columns } logger.info(f"Adding the variables: {list(additions.keys())}") logger.info(f"Transforming the variables: {list(transformations.keys())}") logger.debug( f"Total number of added or transformed variables: {len(additions) + len(transformations)}" ) # TODO: change assert to raise assert ( len(additions) + len(transformations) ) > 0, "No transformations or additions passed" if len(additions) > 0: # pyre-fixme[16]: Optional type has no attribute `assign`. added = all_data.assign(**additions).loc[:, list(additions.keys())] else: added = None if len(transformations) > 0: # NOTE: .copy(deep=False) is used to avoid a false alarm that sometimes happen. # When we take a slice of the DataFrame (all_data[k]), it is passed to a function # inside v from transformations (e.g.: fct_lump), it would then sometimes raise: # SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame. # Adding .copy(deep=False) solves this. # See: https://stackoverflow.com/a/54914752 transformed = pd.DataFrame( # pyre-fixme[16]: Optional type has no attribute `copy`. {k: v(all_data.copy(deep=False)[k]) for k, v in transformations.items()} ) else: transformed = None out = pd.concat((added, transformed), axis=1) dropped_columns = list(set(all_data.columns.values) - set(out.columns.values)) if len(dropped_columns) > 0: if drop: logger.warning(f"Dropping the variables: {dropped_columns}") else: # pyre-fixme[16]: Optional type has no attribute `loc`. out = pd.concat((out, all_data.loc[:, dropped_columns]), axis=1) logger.info(f"Final variables in output: {list(out.columns)}") for column in out: logger.debug( f"Frequency table of column {column}:\n{out[column].value_counts(dropna=False)}" ) logger.debug( f"Number of levels of column {column}:\n{out[column].nunique(dropna=False)}" ) res = tuple(out[i:j] for (i, j) in boundaries) res = tuple(x.set_index(i) for x, i in zip(res, indices)) return res
def _find_adjustment_method( method: Literal["cbps", "ipw", "null", "poststratify", "rake"], WEIGHTING_METHODS: Dict[str, Callable] = BALANCE_WEIGHTING_METHODS, ) -> Callable: """This function translates a string method argument to the function itself. Args: method (Literal["cbps", "ipw", "null", "poststratify", "rake"]): method for adjustment: cbps, ipw, null, poststratify WEIGHTING_METHODS (Dict[str, Callable]): A dict where keys are strings of function names, and the values are the functions themselves. Returns: Callable: The function for adjustment """ if method in WEIGHTING_METHODS.keys(): adjustment_function = WEIGHTING_METHODS[method] else: raise ValueError(f"Unknown adjustment method: '{method}'") return adjustment_function