Source code for balance.utils.pandas_utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import copy
import logging
import warnings
from typing import Any, Dict, NamedTuple

import numpy as np
import pandas as pd
import pandas.api.types as pd_types

logger: logging.Logger = logging.getLogger(__package__)

# TODO: Allow configuring this threshold globally if we need different sensitivity
HIGH_CARDINALITY_RATIO_THRESHOLD: float = 0.8


class HighCardinalityFeature(NamedTuple):
    column: str
    unique_count: int
    unique_ratio: float
    has_missing: bool
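
# Illustrative note (not part of the original module): a record like the one below
# is what the detection helpers further down produce for an ID-like column whose
# non-missing values are almost all unique; the column name "user_id" is made up.
#
#     >>> HighCardinalityFeature(column="user_id", unique_count=98, unique_ratio=0.98, has_missing=True)
#     HighCardinalityFeature(column='user_id', unique_count=98, unique_ratio=0.98, has_missing=True)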

def _compute_cardinality_metrics(series: pd.Series) -> HighCardinalityFeature:
    """Compute cardinality metrics for a feature series.

    The function counts unique non-missing values and their proportion relative
    to non-missing rows, while also tracking whether any missing values are
    present.

    Args:
        series: Feature column to evaluate.

    Returns:
        HighCardinalityFeature: Metrics describing uniqueness and missingness.

    Example:
        >>> import pandas as pd
        >>> s = pd.Series(["a", "b", "c", None, "c"])
        >>> _compute_cardinality_metrics(s)
        HighCardinalityFeature(column='', unique_count=3, unique_ratio=0.75, has_missing=True)
    """
    non_missing = series.dropna()
    unique_count = int(non_missing.nunique()) if not non_missing.empty else 0
    unique_ratio = (
        float(unique_count) / float(len(non_missing)) if len(non_missing) > 0 else 0.0
    )
    return HighCardinalityFeature(
        column="",
        unique_count=unique_count,
        unique_ratio=unique_ratio,
        has_missing=series.isna().any(),
    )


def _detect_high_cardinality_features(
    df: pd.DataFrame,
    threshold: float = HIGH_CARDINALITY_RATIO_THRESHOLD,
) -> list[HighCardinalityFeature]:
    """Identify categorical columns whose non-missing values are mostly unique.

    A feature is flagged when the ratio of unique non-missing values to total
    non-missing rows meets or exceeds ``threshold``. Only categorical columns
    (object, category, string dtypes) are checked, as high cardinality in
    numeric columns is expected and not problematic. Results are sorted by
    descending unique counts for clearer reporting.

    Args:
        df: Dataframe containing candidate features.
        threshold: Minimum unique-to-count ratio to flag a column.

    Returns:
        list[HighCardinalityFeature]: High-cardinality categorical columns
            sorted by descending uniqueness.

    Example:
        >>> import pandas as pd
        >>> df = pd.DataFrame({"id": ["a", "b", "c"], "group": ["a", "a", "b"]})
        >>> _detect_high_cardinality_features(df, threshold=0.8)
        [HighCardinalityFeature(column='id', unique_count=3, unique_ratio=1.0, has_missing=False)]
    """
    high_cardinality_features: list[HighCardinalityFeature] = []
    for column in df.columns:
        # Only check categorical columns (object, category, string dtypes)
        if not _is_categorical_dtype(df[column]):
            continue
        metrics = _compute_cardinality_metrics(df[column])
        if metrics.unique_count == 0:
            continue
        if metrics.unique_ratio < threshold:
            continue
        high_cardinality_features.append(
            HighCardinalityFeature(
                column=column,
                unique_count=metrics.unique_count,
                unique_ratio=metrics.unique_ratio,
                has_missing=metrics.has_missing,
            )
        )
    high_cardinality_features.sort(
        key=lambda feature: feature.unique_count, reverse=True
    )
    return high_cardinality_features


def _coerce_scalar(value: Any) -> float:
    """Safely convert a scalar value to ``float`` for diagnostics.

    ``None`` and non-scalar inputs are converted to ``NaN``. Scalar inputs are
    coerced to ``float`` when possible; otherwise, ``NaN`` is returned instead
    of raising a ``TypeError`` or ``ValueError``. Arrays and sequences return
    ``NaN`` so callers do not need to special-case these inputs.

    Args:
        value: Candidate value to coerce.

    Returns:
        float: ``float`` representation of ``value`` when possible,
            otherwise ``NaN``.

    Example:
        >>> _coerce_scalar(3)
        3.0
        >>> _coerce_scalar("7.125")
        7.125
        >>> _coerce_scalar(True)
        1.0
        >>> _coerce_scalar(complex(1, 2))
        nan
        >>> _coerce_scalar(())
        nan
        >>> _coerce_scalar([1, 2, 3])
        nan
    """
    if value is None:
        return float("nan")
    if np.isscalar(value):
        try:
            return float(value)
        except (TypeError, ValueError):
            return float("nan")
    return float("nan")


def _sorted_unique_categories(values: pd.Series) -> list[Any]:
    """Return sorted unique non-null category values for a Series.

    Args:
        values (pd.Series): Input series of categorical-like values.

    Returns:
        List[Any]: Sorted unique values. If no non-null values exist,
            returns an empty list.

    Examples:
        ::

            import pandas as pd
            from balance.utils.pandas_utils import _sorted_unique_categories

            _sorted_unique_categories(pd.Series(["b", "a", None]))
                # ['a', 'b']
    """
    uniques = pd.unique(values.dropna())
    if len(uniques) == 0:
        return []
    try:
        return sorted(uniques)
    except TypeError:
        return sorted(uniques, key=str)


def _is_categorical_dtype(series: pd.Series) -> bool:
    """Check if a pandas Series has a categorical dtype.

    A dtype is considered categorical if it is object, category, or string type.

    Args:
        series: A pandas Series to check the dtype of.

    Returns:
        bool: True if the Series dtype is categorical (object, category,
            or string), False otherwise.

    Example:
        >>> import pandas as pd
        >>> _is_categorical_dtype(pd.Series(["a", "b"]))
        True
        >>> _is_categorical_dtype(pd.Series([1, 2]))
        False
    """
    dtype = series.dtype
    return (
        pd_types.is_object_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
        or pd_types.is_string_dtype(dtype)
    )


def _process_series_for_missing_mask(series: pd.Series) -> pd.Series:
    """Helper function to process a pandas Series for missing value detection
    while avoiding deprecation warnings from replace and infer_objects.

    Args:
        series (pd.Series): Input series to process

    Returns:
        pd.Series: Boolean series indicating missing values
    """
    # Use _safe_replace_and_infer to avoid downcasting warnings
    replaced_series = _safe_replace_and_infer(series, [np.inf, -np.inf], np.nan)
    return replaced_series.isna()


def _safe_replace_and_infer(
    data: pd.Series | pd.DataFrame,
    to_replace: Any | None = None,
    value: Any | None = None,
) -> pd.Series | pd.DataFrame:
    """Helper function to safely replace values and infer object dtypes
    while avoiding pandas deprecation warnings.

    Args:
        data: pandas Series or DataFrame to process
        to_replace: Value(s) to replace (default: [np.inf, -np.inf])
        value: Value to replace with (default: np.nan)

    Returns:
        Processed Series or DataFrame with proper dtype inference
    """
    if to_replace is None:
        to_replace = [np.inf, -np.inf]
    if value is None:
        value = np.nan
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Downcasting behavior in `replace` is deprecated.*",
            category=FutureWarning,
        )
        return data.replace(to_replace, value).infer_objects(copy=False)
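
# A minimal, illustrative sketch (not part of the original module) of how the
# warning-suppressing helpers above combine: infinities are first mapped to NaN,
# and the result is turned into a boolean missing-value mask.
#
#     >>> import numpy as np
#     >>> import pandas as pd
#     >>> s = pd.Series([1.0, np.inf, None])
#     >>> _safe_replace_and_infer(s).tolist()
#     [1.0, nan, nan]
#     >>> _process_series_for_missing_mask(s).tolist()
#     [False, True, True]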

def _safe_fillna_and_infer(
    data: pd.Series | pd.DataFrame, value: Any | None = None
) -> pd.Series | pd.DataFrame:
    """Helper function to safely fill NaN values and infer object dtypes
    while avoiding pandas deprecation warnings.

    Args:
        data: pandas Series or DataFrame to process
        value: Value to fill NaN with (default: np.nan)

    Returns:
        Processed Series or DataFrame with proper dtype inference
    """
    if value is None:
        value = np.nan
    # Suppress pandas FutureWarnings about downcasting during fillna operations
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        filled_data = data.fillna(value)
        return filled_data.infer_objects(copy=False)


def _safe_groupby_apply(
    data: pd.DataFrame,
    groupby_cols: str | list[str],
    apply_func: Any,
) -> pd.Series:
    """Helper function to safely apply groupby operations while handling the
    include_groups parameter for pandas compatibility.

    Args:
        data: DataFrame to group
        groupby_cols: Column(s) to group by
        apply_func: Function to apply to each group

    Returns:
        Result of groupby apply operation
    """
    # Use include_groups=False to avoid the FutureWarning about operating on grouping
    # columns; fall back to the old behavior if the parameter is not supported.
    try:
        return data.groupby(groupby_cols).apply(apply_func, include_groups=False)
    except TypeError:
        # Fallback for older pandas versions that don't support the include_groups
        # parameter; suppress the FutureWarning the legacy code path can raise.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", FutureWarning)
            return data.groupby(groupby_cols).apply(apply_func)


def _safe_show_legend(axis: Any) -> None:
    """Helper function to safely show a legend only if there are labeled artists,
    avoiding the matplotlib UserWarning about no artists with labels.

    Args:
        axis: matplotlib axis object
    """
    _, labels = axis.get_legend_handles_labels()
    if labels:
        axis.legend()


def _safe_divide_with_zero_handling(numerator: Any, denominator: Any) -> Any:
    """Helper function to safely perform division while handling divide-by-zero
    warnings with proper numpy error state management.

    Args:
        numerator: Numerator for division
        denominator: Denominator for division

    Returns:
        Result of division with proper handling of divide by zero cases
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        # Use numpy.divide to handle division by zero properly
        result = np.divide(numerator, denominator)
    return result
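
# Illustrative usage of the groupby / division helpers above (a sketch, not part
# of the original module); the column names "g" and "x" are made up for the example.
#
#     >>> import numpy as np
#     >>> import pandas as pd
#     >>> df = pd.DataFrame({"g": ["a", "a", "b"], "x": [1.0, 3.0, 5.0]})
#     >>> _safe_groupby_apply(df, "g", lambda group: group["x"].mean()).to_dict()
#     {'a': 2.0, 'b': 5.0}
#     >>> _safe_divide_with_zero_handling(np.array([1.0, 0.0]), np.array([0.0, 0.0]))
#     array([inf, nan])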

def _dict_intersect(d: Dict[Any, Any], d_for_keys: Dict[Any, Any]) -> Dict[Any, Any]:
    """Returns ``d``, keeping only the keys that also appear in ``d_for_keys``.

    Args:
        d (Dict): Dictionary to take values from.
        d_for_keys (Dict): Dictionary whose keys determine which items of ``d`` are kept.

    Returns:
        Dict: Intersection of the two dictionaries' keys, with values taken from ``d``.

    Examples:
        ::

            d1 = {"a": 1, "b": 2}
            d2 = {"c": 3, "b": 2}
            _dict_intersect(d1, d2)
                # {'b': 2}
    """
    intersect_keys = d.keys() & d_for_keys.keys()
    return {k: d[k] for k in intersect_keys}


# TODO: using _astype_in_df_from_dtypes to turn sample.df to original df dtypes may not be a good long term solution.
#       A better solution might require a redesign of some core features.
def _astype_in_df_from_dtypes(
    df: pd.DataFrame, target_dtypes: pd.Series
) -> pd.DataFrame:
    """Returns df with dtypes cast as specified in target_dtypes
    (typically the ``.dtypes`` of an original dataframe).

    Columns that were not in the original dataframe are kept the same.

    Args:
        df (pd.DataFrame): df to convert
        target_dtypes (pd.Series): DataFrame.dtypes to use as target dtypes for conversion

    Returns:
        pd.DataFrame: df with dtypes cast as specified in target_dtypes

    Examples:
        ::

            df = pd.DataFrame({"id": ("1", "2"), "a": (1.0, 2.0), "weight": (1.0, 2.0)})
            df_orig = pd.DataFrame({"id": (1, 2), "a": (1, 2), "forest": ("tree", "banana")})

            df.dtypes.to_dict()
                # {'id': dtype('O'), 'a': dtype('float64'), 'weight': dtype('float64')}
            df_orig.dtypes.to_dict()
                # {'id': dtype('int64'), 'a': dtype('int64'), 'forest': dtype('O')}

            target_dtypes = df_orig.dtypes
            _astype_in_df_from_dtypes(df, target_dtypes).dtypes.to_dict()
                # {'id': dtype('int64'), 'a': dtype('int64'), 'weight': dtype('float64')}
    """
    dict_of_target_dtypes = _dict_intersect(
        target_dtypes.to_dict(),
        df.dtypes.to_dict(),
    )
    return df.astype(dict_of_target_dtypes)


def _are_dtypes_equal(
    dt1: pd.Series, dt2: pd.Series
) -> Dict[str, bool | pd.Series | set[Any]]:
    """Checks whether two dtype Series are equal, and reports the details.

    If the dtypes have an unequal set of items, the comparison will only be about
    the shared set of keys. If there are no shared keys, then return False.

    Args:
        dt1 (pd.Series): first dtype (output from DataFrame.dtypes)
        dt2 (pd.Series): second dtype (output from DataFrame.dtypes)

    Returns:
        Dict[str, Union[bool, pd.Series, set]]: a dict of the following structure

            {
                'is_equal': False,
                'comparison_of_dtypes':
                    flt     True
                    int    False
                    dtype: bool,
                'shared_keys': {'flt', 'int'},
            }

    Examples:
        ::

            df1 = pd.DataFrame({'int': np.arange(5), 'flt': np.random.randn(5)})
            df2 = pd.DataFrame({'flt': np.random.randn(5), 'int': np.random.randn(5)})
            df11 = pd.DataFrame({'int': np.arange(5), 'flt': np.random.randn(5), 'miao': np.random.randn(5)})

            _are_dtypes_equal(df1.dtypes, df1.dtypes)['is_equal']
                # True
            _are_dtypes_equal(df1.dtypes, df2.dtypes)['is_equal']
                # False
            _are_dtypes_equal(df11.dtypes, df2.dtypes)['is_equal']
                # False
    """
    shared_keys = set.intersection(set(dt1.keys()), set(dt2.keys()))
    shared_keys_list = list(shared_keys)
    comparison_of_dtypes = dt1[shared_keys_list] == dt2[shared_keys_list]
    is_equal = np.all(comparison_of_dtypes)

    return {
        "is_equal": is_equal,
        "comparison_of_dtypes": comparison_of_dtypes,
        "shared_keys": shared_keys,
    }

def _warn_of_df_dtypes_change(
    original_df_dtypes: pd.Series,
    new_df_dtypes: pd.Series,
    original_str: str = "df",
    new_str: str = "new_df",
) -> None:
    """Prints a warning if the dtypes of some original df and some modified df differ.

    Args:
        original_df_dtypes (pd.Series): dtypes of original dataframe
        new_df_dtypes (pd.Series): dtypes of modified dataframe
        original_str (str, optional): string to use for warnings when referring to the original. Defaults to "df".
        new_str (str, optional): string to use for warnings when referring to the modified df. Defaults to "new_df".

    Examples:
        ::

            import numpy as np
            import pandas as pd
            from copy import deepcopy
            import balance

            df = pd.DataFrame({"int": np.arange(5), "flt": np.random.randn(5)})
            new_df = deepcopy(df)
            new_df.int = new_df.int.astype(float)
            new_df.flt = new_df.flt.astype(int)

            balance.util._warn_of_df_dtypes_change(df.dtypes, new_df.dtypes)
                # WARNING (2023-02-07 08:01:19,961) [util/_warn_of_df_dtypes_change (line 1696)]: The dtypes of new_df were changed from the original dtypes of the input df, here are the differences -
                # WARNING (2023-02-07 08:01:19,963) [util/_warn_of_df_dtypes_change (line 1707)]: The (old) dtypes that changed for df (before the change):
                # WARNING (2023-02-07 08:01:19,966) [util/_warn_of_df_dtypes_change (line 1710)]:
                # flt    float64
                # int      int64
                # dtype: object
                # WARNING (2023-02-07 08:01:19,971) [util/_warn_of_df_dtypes_change (line 1711)]: The (new) dtypes saved in df (after the change):
                # WARNING (2023-02-07 08:01:19,975) [util/_warn_of_df_dtypes_change (line 1712)]:
                # flt      int64
                # int    float64
                # dtype: object
    """
    compare_df_dtypes_before_and_after = _are_dtypes_equal(
        original_df_dtypes, new_df_dtypes
    )
    if not compare_df_dtypes_before_and_after["is_equal"]:
        logger.warning(
            f"The dtypes of {new_str} were changed from the original dtypes of the input {original_str}, here are the differences - "
        )
        compared_dtypes = compare_df_dtypes_before_and_after["comparison_of_dtypes"]
        dtypes_that_changed = (
            # pyre-ignore[16]: we're only using the pd.Series, so no worries
            compared_dtypes[np.bitwise_not(compared_dtypes.values)]
            .keys()
            .to_list()
        )
        logger.debug(compare_df_dtypes_before_and_after)
        logger.warning(
            f"The (old) dtypes that changed for {original_str} (before the change):"
        )
        logger.warning("\n" + str(original_df_dtypes[dtypes_that_changed]))
        logger.warning(f"The (new) dtypes saved in {original_str} (after the change):")
        logger.warning("\n" + str(new_df_dtypes[dtypes_that_changed]))

def _make_df_column_names_unique(df: pd.DataFrame) -> pd.DataFrame:
    """Make DataFrame column names unique by adding suffixes to duplicates.

    This function iterates through the column names of the input DataFrame and
    appends a suffix to duplicate column names to make them distinct. The suffix
    is an underscore followed by an integer value representing the number of
    occurrences of the column name.

    Args:
        df (pd.DataFrame): The input DataFrame with potentially duplicate column names.

    Returns:
        pd.DataFrame: A DataFrame with unique column names where any duplicate
            column names have been renamed with a suffix.

    Examples:
        ::

            import pandas as pd

            # Sample DataFrame with duplicate column names
            data = {
                "A": [1, 2, 3],
                "B": [4, 5, 6],
                "A2": [7, 8, 9],
                "C": [10, 11, 12],
            }
            df1 = pd.DataFrame(data)
            df1.columns = ["A", "B", "A", "C"]

            _make_df_column_names_unique(df1).to_dict()
                # {'A': {0: 1, 1: 2, 2: 3},
                #  'B': {0: 4, 1: 5, 2: 6},
                #  'A_1': {0: 7, 1: 8, 2: 9},
                #  'C': {0: 10, 1: 11, 2: 12}}
    """
    # Check if all column names are unique
    unique_columns = set(df.columns)
    if len(unique_columns) == len(df.columns):
        return df

    # Else: fix duplicate column names
    logger.warning(
        """Duplicate column names exist in the DataFrame.
        A suffix will be added to them but their order might change from one iteration to another.
        To avoid issues, make sure to change your original column names to be unique (and without special characters)."""
    )

    col_counts = {}
    new_columns = []
    for col in df.columns:
        if col in col_counts:
            col_counts[col] += 1
            new_col_name = f"{col}_{col_counts[col]}"
            logger.warning(
                f"Column {col} already exists in the DataFrame, renaming it to be {new_col_name}"
            )
        else:
            col_counts[col] = 0
            new_col_name = col
        new_columns.append(new_col_name)

    df.columns = new_columns
    return df


def _pd_convert_all_types(
    df: pd.DataFrame, input_type: str, output_type: str
) -> pd.DataFrame:
    """Converts all columns of a given dtype in the input dataframe to a specified output dtype.

    Args:
        df (pd.DataFrame): Input df
        input_type (str): A string of the input type to change.
        output_type (str): A string of the desired output type for the columns of type input_type.

    Returns:
        pd.DataFrame: Output df with columns converted from input_type to output_type.

    Examples:
        ::

            import numpy as np
            import pandas as pd

            df = pd.DataFrame(
                {
                    "a": pd.array([1, 2], dtype=pd.Int64Dtype()),
                    "a2": pd.array([1, 2], dtype=np.int64),
                }
            )
            df.dtypes
                # a     Int64
                # a2    int64
                # dtype: object
            df.dtypes.to_numpy()
                # array([Int64Dtype(), dtype('int64')], dtype=object)

            df2 = _pd_convert_all_types(df, "Int64", "int64")
            df2.dtypes.to_numpy()
                # array([dtype('int64'), dtype('int64')], dtype=object)

            # Casting to float64 might be required so that missing values are handled properly.
            # For details, see: https://stackoverflow.com/a/53853351
            df3 = _pd_convert_all_types(df, "Int64", "float64")
            df3.dtypes.to_numpy()
                # array([dtype('float64'), dtype('float64')], dtype=object)
    """
    df = copy.deepcopy(df)
    # source: https://stackoverflow.com/questions/39904889/
    df = pd.concat(
        [
            df.select_dtypes([], [input_type]),
            df.select_dtypes([input_type]).apply(pd.Series.astype, dtype=output_type),
        ],
        axis=1,
    ).reindex(df.columns, axis=1)
    return df
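
# Illustrative sketch (not part of the original module), assuming a nullable Int64
# column with a missing value: casting to float64 via _pd_convert_all_types keeps
# the missing entry as NaN, as noted in the docstring above.
#
#     >>> import pandas as pd
#     >>> df = pd.DataFrame({"a": pd.array([1, None], dtype=pd.Int64Dtype())})
#     >>> _pd_convert_all_types(df, "Int64", "float64")["a"].tolist()
#     [1.0, nan]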