# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
from typing import Iterable, TYPE_CHECKING
import numpy as np
import pandas as pd
from balance.stats_and_plots.weights_stats import _check_weights_are_valid
from balance.utils.input_validation import _coerce_to_numeric_and_validate
from scipy import stats
if TYPE_CHECKING:
from balance.sample_class import Sample
def _prepare_outcome_and_weights(
y: Iterable[float] | pd.Series | np.ndarray,
w0: Iterable[float] | pd.Series | np.ndarray,
w1: Iterable[float] | pd.Series | np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Validate and prepare outcome and weight arrays for comparison.
This function validates inputs, filters to observations with finite values,
and normalizes weights so that each weight vector has mean 1.
Args:
y: Outcome values.
w0: Baseline weights.
w1: Alternative weights.
Returns:
Tuple of (y_values, w0_normalized, w1_normalized) arrays with finite values
and weights normalized to mean 1.
Raises:
ValueError: If arrays have different lengths or no finite values.
"""
y_series = pd.Series(y)
w0_series = pd.Series(w0)
w1_series = pd.Series(w1)
if (
y_series.shape[0] != w0_series.shape[0]
or y_series.shape[0] != w1_series.shape[0]
):
raise ValueError(
"Outcome and weights must have the same number of observations."
)
_check_weights_are_valid(w0_series)
_check_weights_are_valid(w1_series)
not_null_mask = y_series.notna() & w0_series.notna() & w1_series.notna()
y_series = y_series[not_null_mask]
w0_series = w0_series[not_null_mask]
w1_series = w1_series[not_null_mask]
y_numeric, w0_values = _coerce_to_numeric_and_validate(
y_series, w0_series.to_numpy(), "outcome"
)
_, w1_values = _coerce_to_numeric_and_validate(
y_series, w1_series.to_numpy(), "outcome"
)
finite_mask = (
np.isfinite(w0_values) & np.isfinite(w1_values) & np.isfinite(y_numeric)
)
if not np.any(finite_mask):
raise ValueError("Outcome and weights must contain at least one finite value.")
w0_finite = w0_values[finite_mask]
w1_finite = w1_values[finite_mask]
w0_normalized = w0_finite / np.mean(w0_finite)
w1_normalized = w1_finite / np.mean(w1_finite)
return y_numeric[finite_mask], w0_normalized, w1_normalized
[docs]
def weights_impact_on_outcome_ss(
y: Iterable[float] | pd.Series | np.ndarray,
w0: Iterable[float] | pd.Series | np.ndarray,
w1: Iterable[float] | pd.Series | np.ndarray,
method: str = "t_test",
conf_level: float = 0.95,
) -> pd.Series:
"""
Evaluate whether weighting changes the outcome by testing y*w0 vs y*w1.
Note: Weights are normalized to have mean 1 before computing the weighted
products. In the balance package, weights are typically normalized to sum
to the sample size, so this additional normalization ensures comparability.
Args:
y: Outcome values.
w0: Baseline weights.
w1: Alternative weights.
method: Statistical test to use ("t_test").
conf_level: Confidence level for the mean difference interval.
Returns:
pd.Series: Summary statistics for the weighted outcome comparison.
Examples:
.. code-block:: python
import pandas as pd
from balance.stats_and_plots.impact_of_weights_on_outcome import (
weights_impact_on_outcome_ss,
)
result = weights_impact_on_outcome_ss(
y=pd.Series([1.0, 2.0, 3.0, 4.0]),
w0=pd.Series([1.0, 1.0, 1.0, 1.0]),
w1=pd.Series([1.0, 2.0, 1.0, 2.0]),
method="t_test",
)
print(result.round(3).to_string())
.. code-block:: text
mean_yw0 2.500
mean_yw1 4.000
mean_diff 1.500
diff_ci_lower -1.547
diff_ci_upper 4.547
t_stat 1.567
p_value 0.215
n 4.000
"""
if method != "t_test":
raise ValueError(f"Unsupported method: {method}")
if conf_level <= 0 or conf_level >= 1:
raise ValueError("conf_level must be between 0 and 1.")
y_values, w0_values, w1_values = _prepare_outcome_and_weights(y, w0, w1)
yw0 = y_values * w0_values
yw1 = y_values * w1_values
diff = yw1 - yw0
n_obs = int(diff.shape[0])
diff_std = float(np.std(diff, ddof=1)) if n_obs > 1 else 0.0
mean_yw0 = float(np.mean(yw0))
mean_yw1 = float(np.mean(yw1))
mean_diff = float(np.mean(diff))
if n_obs < 2:
t_stat, p_value = np.nan, np.nan
ci_lower, ci_upper = np.nan, np.nan
elif np.isclose(diff_std, 0.0):
t_stat, p_value = np.nan, np.nan
ci_lower, ci_upper = mean_diff, mean_diff
else:
t_stat, p_value = stats.ttest_rel(yw1, yw0, nan_policy="omit")
t_crit = stats.t.ppf((1 + conf_level) / 2, df=n_obs - 1)
margin = t_crit * diff_std / np.sqrt(n_obs)
ci_lower, ci_upper = mean_diff - margin, mean_diff + margin
return pd.Series(
{
"mean_yw0": mean_yw0,
"mean_yw1": mean_yw1,
"mean_diff": mean_diff,
"diff_ci_lower": ci_lower,
"diff_ci_upper": ci_upper,
"t_stat": float(t_stat) if np.isfinite(t_stat) else np.nan,
"p_value": float(p_value) if np.isfinite(p_value) else np.nan,
"n": n_obs,
}
)
def _validate_adjusted_samples(
adjusted0: "Sample",
adjusted1: "Sample",
) -> tuple["pd.DataFrame", "pd.DataFrame"]:
"""Validate adjusted Samples and return their outcome model matrices.
Args:
adjusted0: First adjusted Sample.
adjusted1: Second adjusted Sample.
Returns:
Tuple of (outcomes0_model_matrix, outcomes1_model_matrix).
Raises:
ValueError: If inputs are not Samples, not adjusted, missing outcomes,
or have mismatched outcome columns.
"""
from balance.sample_class import Sample
if not isinstance(adjusted0, Sample) or not isinstance(adjusted1, Sample):
raise ValueError("compare_adjusted_weighted_outcome_ss expects Sample inputs.")
adjusted0._check_if_adjusted()
adjusted1._check_if_adjusted()
outcomes0 = adjusted0.outcomes()
outcomes1 = adjusted1.outcomes()
if outcomes0 is None or outcomes1 is None:
raise ValueError("Both Samples must include outcomes.")
y0 = outcomes0.model_matrix()
y1 = outcomes1.model_matrix()
if list(y0.columns) != list(y1.columns):
raise ValueError("Outcome columns must match between adjusted Samples.")
return y0, y1
def _align_samples_by_id(
adjusted0: "Sample",
adjusted1: "Sample",
y0: "pd.DataFrame",
y1: "pd.DataFrame",
) -> tuple["pd.DataFrame", np.ndarray, np.ndarray]:
"""Align samples by ID and extract weights for common observations.
Args:
adjusted0: First adjusted Sample.
adjusted1: Second adjusted Sample.
y0: Outcome model matrix from adjusted0.
y1: Outcome model matrix from adjusted1.
Returns:
Tuple of (y_aligned, weights0_aligned, weights1_aligned) where y_aligned
is the outcome DataFrame for common IDs with valid weights.
Raises:
ValueError: If IDs have duplicates, no common IDs exist, outcome values
differ, or no valid weights exist.
"""
ids0 = adjusted0.id_column.to_numpy()
ids1 = adjusted1.id_column.to_numpy()
if pd.Index(ids0).has_duplicates or pd.Index(ids1).has_duplicates:
raise ValueError("Samples must have unique ids to compare outcomes.")
y0_indexed = y0.set_index(adjusted0.id_column)
y1_indexed = y1.set_index(adjusted1.id_column)
common_ids = y0_indexed.index.intersection(y1_indexed.index)
if common_ids.empty:
raise ValueError("Samples do not share any common ids.")
y0_common = y0_indexed.loc[common_ids]
y1_common = y1_indexed.loc[common_ids]
if not y0_common.equals(y1_common):
raise ValueError(
"Outcome values differ between adjusted Samples for common ids."
)
weights0 = adjusted0.weight_column.to_numpy()
weights0_series = pd.Series(weights0, index=ids0)
weights0_aligned = weights0_series.reindex(common_ids).to_numpy()
weights1_series = pd.Series(
adjusted1.weight_column.to_numpy(),
index=ids1,
).reindex(common_ids)
mask = weights1_series.notna().to_numpy()
if not np.any(mask):
raise ValueError(
"Samples do not share any common ids with non-missing weights in adjusted1."
)
y_aligned = y0_common.loc[mask]
weights0_aligned = weights0_aligned[mask]
weights1_aligned = weights1_series.to_numpy()[mask]
return y_aligned, weights0_aligned, weights1_aligned
[docs]
def compare_adjusted_weighted_outcome_ss(
adjusted0: "Sample",
adjusted1: "Sample",
method: str = "t_test",
conf_level: float = 0.95,
round_ndigits: int | None = 3,
) -> pd.DataFrame:
"""
Compare two adjusted Samples by testing outcomes under each set of weights.
Args:
adjusted0: First adjusted Sample (w0).
adjusted1: Second adjusted Sample (w1).
method: Statistical test to use ("t_test").
conf_level: Confidence level for the mean difference interval.
round_ndigits: Optional rounding for numeric outputs.
Returns:
pd.DataFrame: Outcome-by-statistic table comparing weighted outcomes.
Examples:
.. code-block:: python
import pandas as pd
from balance.sample_class import Sample
from balance.stats_and_plots.impact_of_weights_on_outcome import (
compare_adjusted_weighted_outcome_ss,
)
sample = Sample.from_frame(
pd.DataFrame(
{
"id": [1, 2, 3],
"x": [0.1, 0.2, 0.3],
"weight": [1.0, 1.0, 1.0],
"outcome": [1.0, 2.0, 3.0],
}
),
id_column="id",
weight_column="weight",
outcome_columns=("outcome",),
)
target = Sample.from_frame(
pd.DataFrame(
{
"id": [4, 5, 6],
"x": [0.1, 0.2, 0.3],
"weight": [1.0, 1.0, 1.0],
"outcome": [1.0, 2.0, 3.0],
}
),
id_column="id",
weight_column="weight",
outcome_columns=("outcome",),
)
adjusted_a = sample.set_target(target).adjust(method="null")
adjusted_b = sample.set_target(target).adjust(method="null")
adjusted_b.set_weights(pd.Series([1.0, 2.0, 3.0], index=adjusted_b.df.index))
impact = compare_adjusted_weighted_outcome_ss(
adjusted_a, adjusted_b, round_ndigits=3
)
print(impact.to_string())
.. code-block:: text
mean_yw0 mean_yw1 mean_diff diff_ci_lower diff_ci_upper t_stat p_value n
outcome
outcome 2.0 4.667 2.667 -4.922 10.256 1.512 0.27 3.0
"""
y0, y1 = _validate_adjusted_samples(adjusted0, adjusted1)
y_aligned, weights0, weights1 = _align_samples_by_id(adjusted0, adjusted1, y0, y1)
results = {}
for column in y_aligned.columns:
results[column] = weights_impact_on_outcome_ss(
y_aligned[column].to_numpy(),
w0=weights0,
w1=weights1,
method=method,
conf_level=conf_level,
)
impact_df = pd.DataFrame(results).T
impact_df.index.name = "outcome"
if round_ndigits is not None:
numeric_cols = impact_df.select_dtypes(include=["number"]).columns
impact_df[numeric_cols] = impact_df[numeric_cols].round(round_ndigits)
return impact_df