# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
"""BalanceFrame: workflow orchestrator for survey/observational data reweighting.
Pairs a responder SampleFrame with a target SampleFrame and exposes an
immutable adjust() method that returns a new, weight-augmented BalanceFrame.
"""
from __future__ import annotations
import collections
import copy
import logging
import warnings
from copy import deepcopy
from typing import Any, Callable, cast, Literal, overload, TYPE_CHECKING
import numpy as np
import pandas as pd
from balance import adjustment as balance_adjustment, util as balance_util
from balance.adjustment import _find_adjustment_method
from balance.csv_utils import to_csv_with_defaults
from balance.sample_frame import SampleFrame
from balance.stats_and_plots import weights_stats
from balance.summary_utils import _build_diagnostics, _build_summary
from balance.typing import FilePathOrBuffer
from balance.util import (
_assert_type,
_detect_high_cardinality_features,
HighCardinalityFeature,
)
from balance.utils.file_utils import _to_download
from balance.utils.model_matrix import build_design_matrix
if TYPE_CHECKING:
from typing import Self
from balance.balancedf_class import BalanceDFSource # noqa: F401
# The set of string method names accepted by _find_adjustment_method.
# Used as the typing.cast() target in
# BalanceFrame._resolve_adjustment_function when a plain ``str`` is resolved.
_AdjustmentMethodStr = Literal["cbps", "ipw", "null", "poststratify", "rake"]
# Logger named after the enclosing package (``__package__``), not this module.
logger: logging.Logger = logging.getLogger(__package__)
class _CallableBool:
"""A bool-like value that is also callable, for backward-compatible property migration.
This allows properties like ``has_target`` and ``is_adjusted`` to work
both as a property and as a method call::
# Both forms are equivalent:
if bf.has_target: # property-style (preferred)
...
if bf.has_target(): # method-call-style (backward compat)
...
This dual-use pattern was introduced so that code written against the
old ``Sample.has_target()`` method continues to work after the migration
to a property on ``BalanceFrame``.
Args:
value: The boolean value to wrap.
Examples:
>>> cb = _CallableBool(True)
>>> bool(cb)
True
>>> cb()
True
"""
__slots__ = ("_value",)
def __init__(self, value: bool) -> None:
self._value: bool = value
def __bool__(self) -> bool:
return self._value
def __call__(self) -> bool:
return self._value
def __repr__(self) -> str:
return repr(self._value)
def __eq__(self, other: object) -> bool:
if isinstance(other, bool):
return self._value == other
if isinstance(other, _CallableBool):
return self._value == other._value
return NotImplemented
def __hash__(self) -> int:
return hash(self._value)
def __mul__(self, other: object) -> object:
return self._value * other # pyre-ignore[58]
def __rmul__(self, other: object) -> object:
return other * self._value # pyre-ignore[58]
[docs]
class BalanceFrame:
"""A pairing of responder and target SampleFrames for survey weighting.
BalanceFrame holds two :class:`SampleFrame` instances — *responders*
(the sample to be reweighted) and *target* (the population benchmark) —
and provides methods for adjusting responder weights and computing
diagnostics.
BalanceFrame is **immutable by convention**: :meth:`adjust` returns a
*new* BalanceFrame rather than modifying the existing one. This makes
it safe to keep a reference to the pre-adjustment state.
Must be constructed via the public constructor
``BalanceFrame(sample=..., target=...)`` which delegates to the
internal :meth:`_create` factory.
Attributes:
responders (SampleFrame): The responder sample.
target (SampleFrame): The target population.
is_adjusted (bool): Whether :meth:`adjust` has been called.
Examples:
>>> import pandas as pd
>>> from balance.sample_frame import SampleFrame
>>> from balance.balance_frame import BalanceFrame
>>> resp = SampleFrame.from_frame(
... pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
>>> tgt = SampleFrame.from_frame(
... pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
>>> bf = BalanceFrame(sample=resp, target=tgt)
>>> bf.is_adjusted
False
>>> adjusted = bf.adjust(method="ipw")
>>> adjusted.is_adjusted
True
>>> bf.is_adjusted # original unchanged
False
"""
# Class-level declarations of per-instance state. __init__ is a no-op, so
# these are populated by _create() on every new instance.
#
# Baseline responder SampleFrame; _create() sets it to the same object as
# _sf_sample, and is_adjusted compares the two by identity.
# pyre-fixme[13]: Attributes are initialized in _create() / from_frame()
_sf_sample_pre_adjust: SampleFrame
# Current responder SampleFrame (the one whose weights are read).
# pyre-fixme[13]: Attributes are initialized in _create() / from_frame()
_sf_sample: SampleFrame
# Target SampleFrame, or None while no target has been set.
# pyre-fixme[13]: Attributes are initialized in _create() / from_frame()
_sf_target: SampleFrame | None
# Model dict from the most recent adjustment; _create() initializes it to None.
# pyre-fixme[13]: Attributes are initialized in _create() / from_frame()
_adjustment_model: dict[str, Any] | None
# pyre-fixme[4]: Attributes are initialized in from_frame() / _create()
# _links is a defaultdict(list) but by convention stores single objects
# (not lists) for the "target" and "unadjusted" keys. The defaultdict
# type is kept for BalanceDF compatibility which expects .get() semantics.
# Values: _links["target"] → SampleFrame | BalanceFrame
# _links["unadjusted"] → BalanceFrame
# The class-level default is None; has_target checks for that before
# looking up the "target" key.
_links = None
def _sync_sampleframe_state_from_responder(self, responder: SampleFrame) -> None:
    """Mirror a responder's SampleFrame fields onto ``self``.

    Does nothing unless ``self`` is itself a ``SampleFrame`` (e.g.
    ``Sample`` via multiple inheritance). In that case the inherited
    SampleFrame fields are copied from *responder* so they stay
    consistent with ``_sf_sample``.

    Args:
        responder: The SampleFrame whose state is copied.
    """
    if not isinstance(self, SampleFrame):
        return
    # Fixed order, ordinary attribute assignment: any property setter
    # defined for one of these names is still honoured.
    for field_name in (
        "_df",
        "_id_column_name",
        "_column_roles",
        "_weight_column_name",
        "_weight_metadata",
        "_df_dtypes",
    ):
        setattr(self, field_name, getattr(responder, field_name))
@property
def _df_dtypes(self) -> pd.Series | None:
"""Original dtypes, delegated to ``_sf_sample._df_dtypes``."""
return self._sf_sample._df_dtypes
@_df_dtypes.setter
def _df_dtypes(self, value: pd.Series | None) -> None:
self._sf_sample._df_dtypes = value
@property
def id_series(self) -> pd.Series | None:  # pyre-ignore[3]
    """Return ``_sf_sample.id_series`` (the id column data) unchanged."""
    responder = self._sf_sample
    return responder.id_series
@property
def id_column(self) -> str | None:  # pyre-ignore[3]
    """Return the id column *name* stored on ``_sf_sample``.

    Since 0.20.0 this is the name (str) rather than the data (pd.Series);
    use :attr:`id_series` for the data. Every access emits a
    ``FutureWarning`` describing that change.
    """
    # TODO: remove this warning after 2026-06-01
    notice = (
        "Note: id_column now returns the column name (str) since "
        "balance 0.20.0. It previously returned ID data (pd.Series). "
        "Use id_series for ID data."
    )
    # stacklevel=2 attributes the warning to the code reading the property.
    warnings.warn(notice, FutureWarning, stacklevel=2)
    responder = self._sf_sample
    return responder._id_column_name
@property
def weight_series(self) -> pd.Series | None:  # pyre-ignore[3]
    """Return ``_sf_sample.weight_series``, or None if reading it raises ValueError."""
    try:
        active_weight = self._sf_sample.weight_series
    except ValueError:
        # A ValueError from the responder is reported to callers as None.
        return None
    return active_weight
# --- Property descriptors backed by _sf_sample ---
@property
def _df(self) -> pd.DataFrame: # pyre-ignore[3]
"""The internal DataFrame, delegated to ``_sf_sample._df``."""
return self._sf_sample._df
@_df.setter
def _df(self, value: pd.DataFrame | None) -> None: # pyre-ignore[2,3]
if value is None:
raise ValueError(
"Cannot set _df to None. A BalanceFrame must always have a "
"backing DataFrame."
)
self._sf_sample._df = value
@property
def _outcome_columns(self) -> pd.DataFrame | None:
"""Outcome columns as a DataFrame, delegated to ``_sf_sample``."""
outcome_cols = self._sf_sample._column_roles.get("outcomes", [])
if not outcome_cols:
return None
return self._sf_sample._df[outcome_cols]
@_outcome_columns.setter
def _outcome_columns(self, value: pd.DataFrame | None) -> None:
if value is None:
self._sf_sample._column_roles["outcomes"] = []
else:
self._sf_sample._column_roles["outcomes"] = value.columns.tolist()
@property
def _ignored_column_names(self) -> list[str]: # pyre-ignore[3]
"""Ignored column names, delegated to ``_sf_sample.ignored_columns``."""
return self._sf_sample._column_roles.get("ignored", [])
@_ignored_column_names.setter
def _ignored_column_names(
self, value: list[str] | None
) -> None: # pyre-ignore[2,3]
self._sf_sample._column_roles["ignored"] = list(value) if value else []
@property
def df_ignored(self) -> pd.DataFrame | None:
    """Return ``_sf_sample.df_ignored``: the responder's ignored columns, or None."""
    responder = self._sf_sample
    return responder.df_ignored
# -----------------------------------------------------------------------
# Design note: Why __new__ + no-op __init__?
#
# copy.deepcopy() allocates the new instance by calling __new__(cls)
# with NO arguments. If __init__ held the real construction logic
# (with required *sample* and *target* parameters), deepcopy would
# raise TypeError before it could copy any attributes.
#
# The solution used here:
# - __new__ handles BOTH construction paths:
# 1. Public constructor: BalanceFrame(sample=sf1, target=sf2)
# → validates args and delegates to _create().
# 2. deepcopy path: BalanceFrame() (no args)
# → returns a bare object via object.__new__(cls); deepcopy
# then copies attributes onto it directly.
# - __init__ is intentionally a no-op (all state is set by _create()).
# -----------------------------------------------------------------------
def __new__(
cls,
sample: SampleFrame | None = None,
target: SampleFrame | None = None,
) -> BalanceFrame:
"""Create a BalanceFrame from responder and target SampleFrames.
This uses ``__new__`` so that the natural constructor syntax
``BalanceFrame(sample=..., target=...)`` works while still
routing through the validated :meth:`_create` factory.
Args:
sample: The responder / sample data.
target: The target / population data.
Returns:
A new BalanceFrame pairing the two samples.
Raises:
TypeError: If *sample* or *target* is not a SampleFrame.
ValueError: If *sample* and *target* share no covariate
columns.
"""
if sample is None:
# Allow object.__new__(cls) for copy.deepcopy() support.
if target is None:
return object.__new__(cls)
raise TypeError(
"BalanceFrame requires at least a 'sample' argument. "
"Usage: BalanceFrame(sample=sf1) or "
"BalanceFrame(sample=sf1, target=sf2)"
)
return cls._create(sample=sample, target=target)
def __init__(
self,
sample: SampleFrame | None = None,
target: SampleFrame | None = None,
) -> None:
# All initialisation happens in _create(); __init__ is intentionally
# empty so that __new__ + _create() handles everything.
pass
@classmethod
def _create(
    cls,
    sample: SampleFrame,
    target: SampleFrame | None = None,
) -> Self:
    """Internal factory: type-check the inputs and assemble an instance.

    Prefer the public constructor ``BalanceFrame(sample=..., target=...)``.

    Args:
        sample: The responder sample.
        target: The target population. If None, the result has no target
            and can be completed later via :meth:`set_target`.

    Returns:
        A new instance of ``cls`` whose baseline and current responder
        frames are both *sample*.

    Raises:
        TypeError: If *sample* or *target* is not a SampleFrame.
        ValueError: If they share no covariate columns.
    """
    if not isinstance(sample, SampleFrame):
        raise TypeError(
            f"'sample' must be a SampleFrame, got {type(sample).__name__}"
        )
    target_given = target is not None
    if target_given and not isinstance(target, SampleFrame):
        raise TypeError(
            f"'target' must be a SampleFrame, got {type(target).__name__}"
        )

    built = object.__new__(cls)
    # Baseline and current responder start out as the very same object.
    built._sf_sample_pre_adjust = built._sf_sample = sample
    built._sf_target = target
    built._adjustment_model = None
    built._links = collections.defaultdict(list)
    if target_given:
        built._links["target"] = target

    # When the instance is also a SampleFrame (e.g. Sample inherits from
    # both BalanceFrame and SampleFrame), copy the SampleFrame state so the
    # inherited SampleFrame properties work on the instance.
    built._sync_sampleframe_state_from_responder(sample)

    # Covariate overlap is checked last, after the instance is assembled.
    if target_given:
        cls._validate_covariate_overlap(sample, target)
    return built
@staticmethod
def _validate_covariate_overlap(
responders: SampleFrame, target: SampleFrame
) -> None:
"""Check that responders and target share at least one covariate.
When both have no covariates (outcome-only comparison), a warning
is issued instead of raising.
Raises:
ValueError: If both have covariates but share none.
"""
resp_covars = set(responders.covar_columns)
target_covars = set(target.covar_columns)
overlap = resp_covars & target_covars
if len(overlap) == 0:
if len(resp_covars) == 0 and len(target_covars) == 0:
# Both have no covariates — legitimate for outcome-only use.
logger.warning(
"Both responders and target have no covariate columns. "
"adjust() will not be available."
)
return
raise ValueError(
"Responders and target share no covariate columns. "
f"Responder covariates: {sorted(resp_covars)}, "
f"target covariates: {sorted(target_covars)}"
)
if overlap != resp_covars or overlap != target_covars:
logger.warning(
"Responders and target have different covariate columns. "
f"Using {len(overlap)} common variable(s): {sorted(overlap)}. "
f"Responder-only: {sorted(resp_covars - overlap)}, "
f"target-only: {sorted(target_covars - overlap)}."
)
# --- Properties ---
@property
def df_responders(self) -> pd.DataFrame:
    """Return ``_sf_sample.df``: the current responder data as a DataFrame."""
    responder = self._sf_sample
    return responder.df
@property
def df_target(self) -> pd.DataFrame | None:
    """Return ``_sf_target.df``, or None when no target has been set."""
    target_frame = self._sf_target
    return None if target_frame is None else target_frame.df
@property
def df_responders_unadjusted(self) -> pd.DataFrame:
    """Return ``_sf_sample_pre_adjust.df``: the baseline responder data."""
    baseline = self._sf_sample_pre_adjust
    return baseline.df
# --- Backward-compat aliases (to be removed in a future diff) ---
@property
def responders(self) -> SampleFrame:
    """Backward-compat alias returning ``_sf_sample`` (slated for removal)."""
    current = self._sf_sample
    return current
@property
def target(self) -> SampleFrame | None:
    """Backward-compat alias returning ``_sf_target`` (slated for removal)."""
    target_frame = self._sf_target
    return target_frame
@property
def unadjusted(self) -> SampleFrame | None:
    """Backward-compat alias: ``_sf_sample_pre_adjust`` when adjusted, otherwise None."""
    return self._sf_sample_pre_adjust if self.is_adjusted else None
@property
def has_target(self) -> _CallableBool:
    """Report whether a target population is attached.

    True when ``_sf_target`` is set, or when ``_links`` exists and has a
    ``"target"`` entry. The result is a dual-use ``_CallableBool``: both
    ``bf.has_target`` and ``bf.has_target()`` work (the latter for
    backward compatibility).

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp)
        >>> bf.has_target
        False
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> _ = bf.set_target(tgt)
        >>> bf.has_target
        True
    """
    if self._sf_target is not None:
        return _CallableBool(True)
    # Fall back to the links mapping, which may still be the class-level None.
    links = self._links
    return _CallableBool(links is not None and "target" in links)
[docs]
def set_target(
    self, target: BalanceFrame | SampleFrame, inplace: bool | None = None
) -> Self:
    """Set or replace the target population.

    When *target* is a BalanceFrame (or subclass such as Sample), a deep
    copy of ``self`` is returned with the target set (immutable pattern);
    *inplace* is not consulted on that path.

    When *target* is a raw SampleFrame, the behaviour depends on
    *inplace*: True mutates ``self`` (and resets any adjustment state),
    False returns a new object built from a deep copy of the
    pre-adjustment responder frame.

    Args:
        target: The target population — a BalanceFrame/Sample or a
            SampleFrame.
        inplace: If True, mutates self (only meaningful for SampleFrame
            targets). If False, returns a new object. Defaults to None,
            which auto-selects: copy for BalanceFrame targets, inplace
            for SampleFrame targets.

    Returns:
        BalanceFrame with the new target set (``self`` in inplace mode).

    Raises:
        TypeError: If *target* is neither a BalanceFrame nor a SampleFrame.
        ValueError: If responders and target share no covariate columns.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp)
        >>> _ = bf.set_target(tgt)
        >>> bf.has_target()
        True
    """
    # BalanceFrame is tested first: an object may be both a BalanceFrame and
    # a SampleFrame (e.g. Sample), and must take the copy path.
    if isinstance(target, BalanceFrame):
        # Validate before the deep copy so an incompatible target fails
        # without paying for the copy.
        BalanceFrame._validate_covariate_overlap(
            self._sf_sample, target._sf_sample
        )
        new_copy = deepcopy(self)
        new_copy._links["target"] = target
        new_copy._sf_target = target._sf_sample
        return new_copy

    if not isinstance(target, SampleFrame):
        raise TypeError("A target, a Sample object, must be specified")

    # SampleFrame path: default inplace=True for backward compat.
    if inplace is None:
        inplace = True

    if not inplace:
        # _create() validates covariate overlap itself; validating here as
        # well would log every overlap warning twice.
        return type(self)._create(
            sample=copy.deepcopy(self._sf_sample_pre_adjust),
            target=target,
        )

    BalanceFrame._validate_covariate_overlap(self._sf_sample, target)
    if self.is_adjusted:
        logger.warning(
            "Replacing target on an adjusted object resets responder "
            "weights to pre-adjust values and discards current "
            "adjustment results. Pass inplace=False to return a new "
            "object and keep the current adjusted state on this "
            "instance."
        )
    self._sf_target = target
    self._links["target"] = target
    # Reset adjustment state — the old adjustment is no longer valid.
    self._sf_sample = self._sf_sample_pre_adjust
    self._adjustment_model = None
    self._links.pop("unadjusted", None)
    self._sync_sampleframe_state_from_responder(self._sf_sample)
    return self
[docs]
def set_as_pre_adjust(self, *, inplace: bool = False) -> Self:
    """Make the current responder state the new pre-adjust baseline.

    This "locks in" the current responder weights (which may already be
    adjusted and/or trimmed) as the baseline for future diagnostics and
    subsequent adjustments.

    Args:
        inplace: If True, mutate this object and return it. If False
            (default), return a new object built from a deep copy of the
            responder frame.

    Returns:
        BalanceFrame whose ``_sf_sample_pre_adjust`` and ``_sf_sample``
        are the same SampleFrame object. In copy mode only the responder
        frame is deep-copied (the ``_links`` graph is not); the original's
        ``"target"`` link, if any, is carried over by reference. In
        inplace mode the existing responder frame object becomes the
        baseline. In both modes the adjustment model is cleared and the
        ``"unadjusted"`` link is dropped, so the result is no longer
        considered adjusted.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> adjusted = BalanceFrame(sample=resp, target=tgt).adjust(method="null")
        >>> baseline_locked = adjusted.set_as_pre_adjust()  # copy mode
        >>> baseline_locked.is_adjusted
        False
        >>> _ = adjusted.set_as_pre_adjust(inplace=True)  # inplace mode
    """
    if not inplace:
        baseline = copy.deepcopy(self._sf_sample)
        updated = type(self)._create(sample=baseline, target=self._sf_target)
        # Preserve a richer target link (e.g., BalanceFrame/Sample object)
        # when present on the original.
        if "target" in self._links:
            updated._links["target"] = self._links["target"]
    else:
        updated = self
        baseline = self._sf_sample

    # Baseline and current frame become one object, which is exactly what
    # is_adjusted treats as "not adjusted".
    updated._sf_sample_pre_adjust = updated._sf_sample = baseline
    updated._adjustment_model = None
    updated._links.pop("unadjusted", None)
    updated._sync_sampleframe_state_from_responder(baseline)
    return updated
@property
def is_adjusted(self) -> _CallableBool:
    """Report whether this BalanceFrame carries an adjustment.

    The answer is an identity test: adjusted means ``_sf_sample`` is a
    different object from ``_sf_sample_pre_adjust``. The result is a
    ``_CallableBool`` so both ``bf.is_adjusted`` (property) and
    ``bf.is_adjusted()`` (legacy call) work.

    For compound adjustments (calling ``adjust()`` multiple times) this
    is True after the first adjustment and stays True afterwards, since
    the baseline kept in ``_sf_sample_pre_adjust`` is the original one.
    """
    current, baseline = self._sf_sample, self._sf_sample_pre_adjust
    return _CallableBool(current is not baseline)
# --- Adjustment ---
def _resolve_adjustment_function(
self, method: str | Callable[..., Any]
) -> Callable[..., Any]:
"""Resolve a weighting method string or callable to a function.
Args:
method: A string naming a built-in method or a callable.
Returns:
The resolved adjustment function.
Raises:
ValueError: If *method* is not a valid string or callable.
"""
if isinstance(method, str):
return _find_adjustment_method(cast(_AdjustmentMethodStr, method))
if callable(method):
return method
raise ValueError(
"'method' must be a string naming a weighting method or a callable"
)
def _get_covars(
self,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Get covariate DataFrames for responders and target.
Returns:
A (responder_covars, target_covars) tuple of DataFrames.
Raises:
ValueError: If no target is set.
"""
if self._sf_target is None:
raise ValueError("Cannot get covars without a target population.")
return self._sf_sample.df_covars, self._sf_target.df_covars
def _build_adjusted_frame(
    self,
    result: dict[str, Any],
    method: str | Callable[..., Any],
) -> Self:
    """Assemble the adjusted BalanceFrame from a weighting function's output.

    Args:
        result: The dict returned by the weighting function, containing
            at least ``"weight"`` and optionally ``"model"``.
        method: The original method argument (string or callable).

    Returns:
        A new, adjusted BalanceFrame (or subclass instance if called on
        a subclass).
    """
    new_responders = copy.deepcopy(self._sf_sample)
    if isinstance(method, str):
        method_name = method
    else:
        method_name = getattr(method, "__name__", str(method))

    # --- Unified weight history tracking ---
    #
    # After each adjustment the SampleFrame accumulates weight columns:
    #
    # | After        | Weight columns                                    | Active |
    # |--------------|---------------------------------------------------|--------|
    # | Before adj.  | weight                                            | weight |
    # | 1st adjust   | weight, weight_pre_adjust, weight_adjusted_1      | weight |
    # | 2nd adjust   | weight, weight_pre_adjust, weight_adjusted_1, _2  | weight |
    # | 3rd adjust   | ... weight_adjusted_1, _2, _3                     | weight |
    #
    # * weight_pre_adjust — frozen copy of original design weights (first adj only)
    # * weight_adjusted_N — output of the Nth adjustment step
    # * weight            — always overwritten with the latest adjusted values
    original_weight_name = str(_assert_type(self.weight_series).name)

    # First adjustment only: freeze the design weights as
    # "weight_pre_adjust" so the full history lives in one SampleFrame.
    if "weight_pre_adjust" not in new_responders._df.columns:
        frozen_design_weights = new_responders._df[original_weight_name].copy()
        new_responders.add_weight_column("weight_pre_adjust", frozen_design_weights)

    # Next global action number (shared counter across adjusted/trimmed).
    n = new_responders._next_weight_action_number()
    adj_col_name = f"weight_adjusted_{n}"

    # Record this step's output as weight_adjusted_N, with its metadata.
    step_metadata = {
        "method": method_name,
        "adjusted": True,
        "model": result.get("model", {}),
    }
    new_responders.add_weight_column(
        adj_col_name, result["weight"], metadata=step_metadata
    )

    # Overwrite the original weight column with the new adjusted values,
    # so the active weight column always keeps its original name.
    # use_index=True lets na_action="drop" (which returns fewer weights)
    # fill dropped rows with NaN; set_weights warns about missing indices.
    new_responders.set_weights(result["weight"], use_index=True)

    # TODO: The weight history columns (weight_pre_adjust, weight_adjusted_1,
    # weight_adjusted_2, ...) make _sf_sample_pre_adjust redundant. Once all
    # consumers are updated to read weight_pre_adjust instead,
    # _sf_sample_pre_adjust can be removed entirely.

    # type(self) so subclasses (e.g. Sample) get their own type back.
    new_bf = type(self)._create(sample=new_responders, target=self._sf_target)

    # Keep pointing at the *very first* baseline, so for compound
    # adjustments diagnostics (asmd_improvement, summary) show total
    # improvement.
    new_bf._sf_sample_pre_adjust = self._sf_sample_pre_adjust

    # Link back to the original unadjusted BalanceFrame (inherited from
    # ``self`` if it already has one) so that 3-way comparisons span the
    # full adjustment chain, not just the last step.
    new_bf._links["unadjusted"] = self._links.get("unadjusted", self)
    if "target" in self._links:
        new_bf._links["target"] = self._links["target"]

    raw_model = result.get("model")
    if not isinstance(raw_model, dict):
        new_bf._adjustment_model = raw_model
        return new_bf

    # Defensive (shallow) copy: the weighting function may retain a reference
    # to the dict it returned, so mutating it here could cause surprises.
    # TODO: Track adjustment history — currently only the latest model is
    # stored. A future enhancement should maintain a list of
    # (method, model_dict) tuples for each adjustment step.
    model_copy = dict(raw_model)
    new_bf._adjustment_model = model_copy

    # Preserve the raw model's method name (e.g. "null_adjustment") when
    # present; only set a fallback when the model doesn't include one.
    model_copy.setdefault("method", method_name)
    if model_copy.get("method") != "ipw":
        return new_bf

    # Preserve training-time design weights only when the weighting method
    # already opted into fit metadata. Target weights are considered only
    # when a target frame exists.
    weight_keys = [("fit_sample_weights", "training_sample_weights")]
    if new_bf._sf_target is not None:
        weight_keys.append(("fit_target_weights", "training_target_weights"))
    for fit_key, training_key in weight_keys:
        fitted = model_copy.get(fit_key)
        if isinstance(fitted, pd.Series):
            model_copy.setdefault(training_key, fitted)
    return new_bf
[docs]
def adjust(
    self,
    target: BalanceFrame | None = None,
    method: str | Callable[..., Any] = "ipw",
    *args: Any,
    **kwargs: Any,
) -> Self:
    """Adjust responder weights to match the target. Returns a NEW BalanceFrame.

    The original BalanceFrame is not modified (immutable pattern). The
    returned BalanceFrame has ``is_adjusted == True`` and the pre-adjustment
    responders stored in :attr:`unadjusted`.

    The active weight column always keeps its original name (e.g.,
    ``"weight"``). Its values are overwritten with the new adjusted
    weights. The full weight history is tracked via additional columns:

    .. list-table:: Weight columns after each adjustment
       :header-rows: 1

       * - After
         - Weight columns in ``responders``
         - Active (``"weight"``)
       * - Before adjust
         - ``weight``
         - original design weights
       * - 1st adjust
         - ``weight``, ``weight_pre_adjust``, ``weight_adjusted_1``
         - = ``weight_adjusted_1`` values
       * - 2nd adjust
         - + ``weight_adjusted_2``
         - = ``weight_adjusted_2`` values
       * - 3rd adjust
         - + ``weight_adjusted_3``
         - = ``weight_adjusted_3`` values

    **Compound / sequential adjustments:** ``adjust()`` can be called
    multiple times. Each call uses the *current* (previously adjusted)
    weights as design weights, so adjustments compound. For example, run
    IPW first to correct broad imbalances, then rake on a specific variable
    for fine-tuning::

        adjusted_ipw = bf.adjust(method="ipw", max_de=2)
        adjusted_final = adjusted_ipw.adjust(method="rake")

    The original unadjusted baseline is always preserved:

    * ``_sf_sample_pre_adjust`` always points to the **original**
      (pre-first-adjustment) SampleFrame.
    * ``_links["unadjusted"]`` always points to the **original**
      unadjusted BalanceFrame, so 3-way comparisons
      (adjusted vs original vs target) and ``asmd_improvement()`` show
      **total** improvement across all adjustment steps.
    * ``model`` stores only the **latest** adjustment's model dict.

    Args:
        target: Optional target BalanceFrame/Sample. If provided, it is
            attached via ``set_target(target, inplace=False)`` and the
            resulting copy is adjusted, so ``self`` is left untouched.
            If None, uses the already-set target.
        method: The weighting method to use. Built-in options:
            ``"ipw"``, ``"cbps"``, ``"rake"``, ``"poststratify"``,
            ``"null"``. A callable with the same signature as the
            built-in methods is also accepted.
        *args: Extra positional arguments. Accepted for backward
            compatibility only; they are not passed to the adjustment
            function.
        **kwargs: Additional keyword arguments forwarded to the adjustment
            function (e.g. ``max_de``, ``transformations``).

    Returns:
        A new, adjusted BalanceFrame.

    Raises:
        ValueError: If *method* is a string that doesn't match any
            registered adjustment method.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2, 3], "x": [10.0, 20.0, 30.0],
        ...                   "weight": [1.0, 1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [4, 5, 6], "x": [15.0, 25.0, 35.0],
        ...                   "weight": [1.0, 1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> adjusted = bf.adjust(method="ipw")
        >>> adjusted.is_adjusted
        True
        >>> adjusted2 = adjusted.adjust(method="null")
        >>> adjusted2.is_adjusted
        True
    """
    if target is not None:
        # Inline target: attach it to a copy, then adjust that copy.
        # inplace=False keeps ``self`` untouched even if *target* is a plain
        # SampleFrame, for which set_target() would otherwise default to
        # mutating in place.
        self_with_target = self.set_target(target, inplace=False)
        # ``target`` and ``method`` are passed positionally so any extra
        # ``*args`` stay in ``*args`` on the recursive call instead of being
        # re-bound to ``target`` / ``method``.
        return self_with_target.adjust(None, method, *args, **kwargs)

    self._require_target()
    sf_target = self._sf_target
    assert sf_target is not None  # guaranteed by _require_target() above

    adjustment_function = self._resolve_adjustment_function(method)
    resp_covars, target_covars = self._get_covars()

    # Warn when the target dwarfs the sample (>100k rows and >=10x the sample).
    num_rows_sample = resp_covars.shape[0]
    num_rows_target = target_covars.shape[0]
    if (
        num_rows_sample > 0
        and num_rows_target > 100_000
        and num_rows_target >= 10 * num_rows_sample
    ):
        logger.warning(
            "Large target detected: %s target rows vs %s sample rows. "
            "When the target is much larger than the sample (here >10x and >100k rows), "
            "the target's contribution to variance becomes negligible. "
            "Standard errors will be driven almost entirely by the sample, "
            "similar to a one-sample inference setting.",
            num_rows_target,
            num_rows_sample,
        )

    # Detect high-cardinality features in both responder and target covariates.
    sample_high_card = _detect_high_cardinality_features(resp_covars)
    target_high_card = _detect_high_cardinality_features(target_covars)

    # Merge the results, taking the maximum unique_count for each column.
    high_cardinality_dict: dict[str, HighCardinalityFeature] = {}
    for feature in sample_high_card + target_high_card:
        if (
            feature.column not in high_cardinality_dict
            or feature.unique_count
            > high_cardinality_dict[feature.column].unique_count
        ):
            high_cardinality_dict[feature.column] = feature

    # Report the widest columns first.
    high_cardinality_features = sorted(
        high_cardinality_dict.values(),
        key=lambda f: f.unique_count,
        reverse=True,
    )
    if high_cardinality_features:
        formatted_details = ", ".join(
            f"{feature.column} (unique={feature.unique_count}; "
            f"unique_ratio={feature.unique_ratio:.2f}"
            f"{'; missing values present' if feature.has_missing else ''}"
            f")"
            for feature in high_cardinality_features
        )
        logger.warning(
            "High-cardinality features detected that may not provide signal: "
            + formatted_details
        )

    # The first weight column of each frame is passed as that side's design
    # weights; only **kwargs are forwarded to the weighting function.
    result = adjustment_function(
        sample_df=resp_covars,
        sample_weights=self._sf_sample.df_weights.iloc[:, 0],
        target_df=target_covars,
        target_weights=sf_target.df_weights.iloc[:, 0],
        **kwargs,
    )
    return self._build_adjusted_frame(result, method)
[docs]
def fit(
    self,
    *,
    target: BalanceFrame | SampleFrame | None = None,
    method: str | Callable[..., Any] = "ipw",
    inplace: bool = True,
    **kwargs: Any,
) -> Self:
    """Fit a weighting model on this BalanceFrame (sklearn-style entry point).

    The method is resolved, IPW-specific defaults are applied, and the
    actual weighting is delegated to :meth:`adjust`. Because fitting a
    weighting model and producing adjusted weights are the same step, the
    returned object carries both the model and the adjusted weights.

    **In-place (default):**

    .. code-block:: python

        bf = BalanceFrame(sample=respondents, target=population)
        bf.fit(method="ipw")   # mutates bf and returns it
        bf.weights().df        # the adjusted weights

    **Functional:**

    .. code-block:: python

        adjusted = bf.fit(method="ipw", inplace=False)

    **Fit on a subset, apply to a holdout:**

    .. code-block:: python

        fitted = train_bf.fit(method="ipw")
        scored = holdout_bf.set_fitted_model(fitted, inplace=False)
        holdout_weights = scored.predict_weights()
        # or in a single call:
        fitted.predict_weights(data=holdout_bf)

    Args:
        target: Optional target population. A ``SampleFrame`` or
            ``BalanceFrame`` is attached through
            ``set_target(target, inplace=False)`` before adjusting, so
            ``self`` is not touched by that step. Any other value is
            forwarded to :meth:`adjust` as-is.
        method: Adjustment method name (``"ipw"``, ``"cbps"``, ``"rake"``,
            ``"poststratify"``, ``"null"``) or a custom callable with the
            weighting-method signature.
        inplace: When True (default) the fitted state is copied onto
            ``self`` and ``self`` is returned. When False the adjusted
            object produced by :meth:`adjust` is returned and ``self`` is
            left as it was.
        **kwargs: Keyword arguments forwarded to :meth:`adjust`.

    Returns:
        ``self`` when ``inplace=True``; otherwise the newly adjusted
        BalanceFrame.

    Raises:
        ValueError: If the built-in IPW method is combined with
            ``na_action='drop'`` while ``store_fit_matrices`` or
            ``store_fit_metadata`` is enabled. Errors raised while
            resolving ``method`` or inside :meth:`adjust` propagate
            unchanged.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [0.0, 1.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [0.2, 0.8], "weight": [1.0, 1.0]}))
        >>> adjusted = BalanceFrame(sample=resp, target=tgt).fit(method="null")
        >>> bool(adjusted.is_adjusted)
        True

    Notes:
        For the built-in IPW method, ``store_fit_metadata`` and
        ``store_fit_matrices`` default to ``True`` here so that
        ``design_matrix()``/``predict_proba()``/``predict_weights()`` can
        reuse fit-time artifacts. Pass them explicitly as ``False`` to
        opt out (e.g. to save memory on large inputs).
    """
    from balance.weighting_methods.ipw import ipw as built_in_ipw

    if self._resolve_adjustment_function(method) is built_in_ipw:
        # Built-in IPW keeps its fit-time artifacts unless told otherwise.
        for artifact_flag in ("store_fit_matrices", "store_fit_metadata"):
            kwargs.setdefault(artifact_flag, True)
        drops_rows = kwargs.get("na_action", "add_indicator") == "drop"
        keeps_matrices = bool(kwargs.get("store_fit_matrices"))
        keeps_metadata = bool(kwargs.get("store_fit_metadata"))
        if drops_rows and (keeps_matrices or keeps_metadata):
            raise ValueError(
                "BalanceFrame.fit(method='ipw', na_action='drop') is incompatible "
                "with stored fit artifacts because dropped rows break index/shape "
                "alignment for design_matrix/predict_proba. Use na_action='add_indicator', "
                "or disable store_fit_matrices/store_fit_metadata."
            )

    if isinstance(target, (SampleFrame, BalanceFrame)):
        # Attach the target on a copy so this step never mutates self.
        with_target = self.set_target(target, inplace=False)
        fitted = with_target.adjust(method=method, **kwargs)
    else:
        fitted = self.adjust(target=target, method=method, **kwargs)

    if inplace:
        # Adopt the adjusted object's state on self.
        self._sf_sample = fitted._sf_sample
        self._sf_sample_pre_adjust = fitted._sf_sample_pre_adjust
        self._sf_target = fitted._sf_target
        self._adjustment_model = fitted._adjustment_model
        self._links = fitted._links
        self._sync_sampleframe_state_from_responder(self._sf_sample)
        return self
    return fitted
[docs]
def set_fitted_model(self, fitted: BalanceFrame, *, inplace: bool = True) -> Self:
    """Score this BalanceFrame with a model fitted on another BalanceFrame.

    Supports fit-then-apply workflows: fit on one BalanceFrame (for
    example a small training subset) and reuse that model on another one
    that shares the same covariate columns. The stored estimator scores
    this object's sample rows, the resulting weights replace the sample
    weights, and the model plus an ``"unadjusted"`` link are recorded on
    the result.

    **Returning a new object (inplace=False):**

    .. code-block:: python

        fitted = train_bf.fit(method="ipw")
        scored = holdout_bf.set_fitted_model(fitted, inplace=False)
        scored.summary()

    **Mutating self (inplace=True, default):**

    .. code-block:: python

        holdout_bf.set_fitted_model(fitted)
        holdout_bf.summary()

    Only IPW models are accepted at the moment.

    Args:
        fitted: A BalanceFrame whose ``model`` is an IPW model dict with a
            stored estimator under ``"fit"``.
        inplace: When True (default) the weights and model are written
            onto ``self``. When False a new BalanceFrame built from deep
            copies of this object's sample and target is populated and
            returned instead.

    Returns:
        ``self`` when ``inplace=True``; otherwise the new BalanceFrame.

    Raises:
        ValueError: If ``fitted`` has no model, the model is not a dict,
            its method is not ``"ipw"``, the estimator is missing, or the
            sample/target covariate column names of ``self`` and
            ``fitted`` differ.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> train_resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [0.0, 1.0], "weight": [1.0, 1.0]}))
        >>> train_tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [0.2, 0.8], "weight": [1.0, 1.0]}))
        >>> holdout_resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [5, 6], "x": [0.1, 0.9], "weight": [1.0, 1.0]}))
        >>> holdout_tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [7, 8], "x": [0.3, 0.7], "weight": [1.0, 1.0]}))
        >>> train_bf = BalanceFrame(sample=train_resp, target=train_tgt)
        >>> holdout_bf = BalanceFrame(sample=holdout_resp, target=holdout_tgt)
        >>> fitted = train_bf.fit(method="ipw")
        >>> scored = holdout_bf.set_fitted_model(fitted, inplace=False)
        >>> scored.is_adjusted
        True
        >>> scored.model is not None
        True
    """
    from balance.weighting_methods.ipw import link_transform, weights_from_link

    # --- Validate the donor model ---
    model = fitted.model
    if model is None:
        raise ValueError(
            "fitted must be an adjusted BalanceFrame with a stored model."
        )
    if not isinstance(model, dict):
        raise ValueError("fitted must contain a valid adjustment model dict.")
    method = model.get("method")
    if method != "ipw":
        raise ValueError(
            "set_fitted_model() currently supports only IPW models. "
            f"The fitted model uses method '{method}'. "
            "Other methods (CBPS, rake, poststratify) will be supported "
            "once they store fit-time artifacts."
        )
    if model.get("fit") is None:
        raise ValueError("fitted IPW model is missing estimator information.")

    # --- Validate that covariate schemas agree (order-insensitive) ---
    own_sample_cols = set(self._sf_sample.covars().df.columns)
    donor_sample_cols = set(fitted._sf_sample.covars().df.columns)
    if own_sample_cols != donor_sample_cols:
        raise ValueError(
            "self and fitted must have matching sample covariate column names."
        )
    if self._sf_target is not None and fitted._sf_target is not None:
        own_target_cols = set(_assert_type(self._sf_target).covars().df.columns)
        donor_target_cols = set(
            _assert_type(fitted._sf_target).covars().df.columns
        )
        if own_target_cols != donor_target_cols:
            raise ValueError(
                "self and fitted must have matching target covariate column names."
            )

    # --- Score this object's sample rows with the stored estimator ---
    holdout_matrix, _unused_target_matrix = self._compute_ipw_matrices(
        model, source=self
    )
    estimator = _assert_type(model.get("fit"))
    positive_column = self._ipw_class_index(estimator)
    propensity = np.asarray(
        estimator.predict_proba(holdout_matrix)[:, positive_column]
    )
    link = link_transform(propensity)

    # Weights come from this object, not from the training frame.
    sample_weights = self._sf_sample.df_weights.iloc[:, 0]
    target_weights = _assert_type(self._sf_target).df_weights.iloc[:, 0]

    # Flag a >1% relative gap between holdout and training target totals.
    training_target_weights = model.get("training_target_weights")
    if isinstance(training_target_weights, pd.Series):
        train_sum = training_target_weights.sum()
        data_sum = target_weights.sum()
        if train_sum > 0 and abs(train_sum - data_sum) / train_sum > 0.01:
            logger.warning(
                "set_fitted_model(): holdout target weights sum (%.2f) "
                "differs from training target weights sum (%.2f). The "
                "balance_classes correction and weight normalization will "
                "use holdout's weights, which may produce different results "
                "than the training fit.",
                data_sum,
                train_sum,
            )

    predicted = weights_from_link(
        link=link,
        balance_classes=bool(model.get("balance_classes", True)),
        sample_weights=sample_weights,
        target_weights=target_weights,
        weight_trimming_mean_ratio=model.get("weight_trimming_mean_ratio"),
        weight_trimming_percentile=model.get("weight_trimming_percentile"),
    )

    # --- Pick the object that receives the result ---
    bf = self
    if not inplace:
        bf = type(self)._create(
            sample=deepcopy(self._sf_sample),
            target=deepcopy(self._sf_target),
        )
        if "target" in self._links:
            bf._links["target"] = deepcopy(self._links["target"])

    # The baseline frame must be detached BEFORE the weights are overwritten,
    # otherwise the "unadjusted" view would share the new weights.
    shares_baseline = bf._sf_sample is bf._sf_sample_pre_adjust
    if not bf.is_adjusted and shares_baseline:
        bf._sf_sample_pre_adjust = deepcopy(bf._sf_sample_pre_adjust)

    new_weights = pd.Series(predicted.values, index=bf._sf_sample.df.index)
    bf._sf_sample.set_weights(new_weights, use_index=True)

    # Record the model (shallow copy) and expose the pre-adjust baseline.
    bf._adjustment_model = dict(model)
    bf._links["unadjusted"] = type(self)._create(
        sample=bf._sf_sample_pre_adjust,
        target=bf._sf_target,
    )
    bf._sync_sampleframe_state_from_responder(bf._sf_sample)
    return bf
def _require_fitted_model(self) -> dict[str, Any]:
    """Hand back the stored adjustment model dict, raising if there is none.

    The dict plays three roles at once: configuration (formula, na_action,
    one_hot_encoding), fitted artifacts (estimator, scaler, column names),
    and cache (model_matrix_sample, sample_probability, ...). Nested
    values should be treated as read-only; refresh the cache by replacing
    keys rather than mutating shared objects.

    Raises:
        ValueError: If ``self._adjustment_model`` is not a dict (this
            includes ``None``).
    """
    stored = self._adjustment_model
    # ``None`` is not a dict, so one isinstance check covers both cases.
    if isinstance(stored, dict):
        return stored
    raise ValueError(
        "This operation requires an adjusted model. "
        "Call fit()/adjust() first, or apply a model via "
        "set_fitted_model()."
    )
def _require_ipw_model(self) -> dict[str, Any]:
    """Return the stored model dict after checking it is a usable IPW model.

    Raises:
        ValueError: If no model is stored, if its ``"method"`` is not
            ``"ipw"``, if ``"fit"`` is missing, or if
            ``"X_matrix_columns"`` is not a list.
    """
    model = self._require_fitted_model()
    method = model.get("method")
    if method != "ipw":
        raise ValueError(
            "design_matrix() and predict_proba() currently support only IPW models. "
            f"The current model uses method '{method}'."
        )
    has_estimator = model.get("fit") is not None
    has_column_list = isinstance(model.get("X_matrix_columns"), list)
    if not (has_estimator and has_column_list):
        raise ValueError("IPW model metadata is missing fitted model information.")
    return model
def _matrix_to_dataframe(
    self,
    matrix: Any,
    index: pd.Index,
    columns: list[str],
) -> pd.DataFrame:
    """Wrap a stored model matrix in a DataFrame with the given columns.

    Args:
        matrix: A pandas DataFrame, a NumPy array, or a SciPy ``spmatrix``.
        index: Row labels applied to array and sparse inputs. A DataFrame
            input keeps its own index.
        columns: Column labels. A DataFrame input is reindexed to them;
            array and sparse inputs are labelled with them.

    Returns:
        The matrix as a DataFrame (sparse-backed for sparse input).

    Raises:
        ValueError: If ``matrix`` is none of the supported types.
    """
    if isinstance(matrix, pd.DataFrame):
        return matrix.reindex(columns=columns)
    if isinstance(matrix, np.ndarray):
        return pd.DataFrame(matrix, index=index, columns=columns)

    # SciPy is only imported once the cheaper type checks have failed.
    from scipy.sparse import spmatrix

    if not isinstance(matrix, spmatrix):
        raise ValueError(
            "Stored IPW fit-time model matrix is unavailable for this configuration."
        )
    return pd.DataFrame.sparse.from_spmatrix(
        matrix,
        index=index,
        columns=columns,
    )
def _align_to_index(
self,
data: pd.DataFrame | pd.Series,
index: pd.Index,
caller: str,
) -> pd.DataFrame | pd.Series:
"""Align a DataFrame or Series to the given index.
Handles unique indices (via reindex), reordered indices (via reindex),
and non-unique indices of equal length (via set_axis). Raises if
indices are incompatible.
"""
if data.index.is_unique and index.is_unique:
if len(data.index) == len(index) and not data.index.equals(index):
if not (data.index.isin(index).all() and index.isin(data.index).all()):
raise ValueError(
f"Stored IPW {caller} output index does not match the current "
"data index. Re-fit with BalanceFrame.fit(method='ipw') or "
"attach a model trained on matching rows."
)
return data.reindex(index)
if len(data.index) != len(index):
raise ValueError(
f"Stored IPW {caller} output cannot be aligned to the current "
"index because lengths differ. Re-fit with "
"BalanceFrame.fit(method='ipw') to refresh stored artifacts."
)
return data.set_axis(index, axis=0)
@staticmethod
def _ipw_class_index(fit_model: Any) -> int:
classes_attr = getattr(fit_model, "classes_", None)
if classes_attr is None:
raise ValueError(
"Stored IPW estimator is missing classes_ needed for predict_proba()."
)
classes = list(classes_attr)
if 1 not in classes:
raise ValueError("Stored IPW estimator is missing class label 1.")
return classes.index(1)
def _compute_ipw_matrices(
    self,
    model: dict[str, Any],
    source: BalanceFrame | None = None,
) -> tuple[Any, Any]:
    """Rebuild sample and target design matrices from stored IPW settings.

    Covariates are copied, passed through the stored transformations, and
    fed to ``build_design_matrix`` together with the stored preprocessing
    settings. This method does not write anything back to ``model``.

    Args:
        model: The fitted model dict holding the preprocessing
            configuration (transformations, formula, column list, scaler,
            penalties, ...).
        source: BalanceFrame that supplies the covariates. ``None``
            (default) means ``self``; a BalanceFrame is used for the
            ``data=`` holdout path.

    Returns:
        ``(sample_matrix, target_matrix)``: the combined matrix split at
        the sample row count reported by ``build_design_matrix``.

    Raises:
        ValueError: If ``source`` is given but has no target. When
            ``source`` is ``None`` the target requirement is enforced by
            ``self._require_target()``.
    """
    if source is None:
        self._require_target()
        frame = self
    else:
        if source._sf_target is None:
            raise ValueError(
                "data must have a target set when computing design matrices."
            )
        frame = source

    # Work on copies so the transformations never touch the stored frames.
    raw_sample = frame._sf_sample.covars().df.copy()
    raw_target = _assert_type(frame._sf_target).covars().df.copy()
    sample_covars, target_covars = balance_adjustment.apply_transformations(
        (raw_sample, raw_target),
        transformations=model.get("transformations", "default"),
    )

    columns: list[str] = _assert_type(model.get("X_matrix_columns"), list)
    na_action = cast(str, model.get("na_action", "add_indicator"))

    from scipy.sparse import spmatrix

    # Older model dicts may lack "fit_matrix_type"; infer it from the type
    # of the stored sample matrix when possible, else leave it as None.
    matrix_type = model.get("fit_matrix_type")
    if matrix_type is None:
        stored_sample_matrix = model.get("model_matrix_sample")
        for matrix_class, label in (
            (spmatrix, "sparse"),
            (np.ndarray, "dense"),
            (pd.DataFrame, "dataframe"),
        ):
            if isinstance(stored_sample_matrix, matrix_class):
                matrix_type = label
                break

    built = build_design_matrix(
        sample_covars,
        target_covars,
        use_model_matrix=bool(model.get("use_model_matrix", True)),
        formula=model.get("formula"),
        one_hot_encoding=bool(model.get("one_hot_encoding", False)),
        na_action=na_action,
        project_to_columns=columns,
        fit_scaler=model.get("fit_scaler"),
        fit_penalties_skl=model.get("fit_penalties_skl"),
        matrix_type=matrix_type,
    )
    stacked = built["combined_matrix"]
    n_sample_rows = built["sample_n"]
    # Sample rows come first in the combined matrix, target rows after.
    return stacked[:n_sample_rows], stacked[n_sample_rows:]
@staticmethod
def _is_artifact_stale(
artifact: Any,
stored_idx: pd.Index,
current_idx: pd.Index,
) -> bool:
"""Check whether a stored artifact's index is stale vs current data."""
if artifact is None:
return True
artifact_len = getattr(artifact, "shape", [0])[0]
if artifact_len != len(current_idx):
return True
if len(stored_idx) == len(current_idx):
same_set = (
stored_idx.isin(current_idx).all()
and current_idx.isin(stored_idx).all()
)
if not same_set:
return True
return False
def _ensure_fresh_ipw_artifacts(
    self,
    model: dict[str, Any],
    side: Literal["sample", "target"],
) -> None:
    """Refresh cached IPW matrices and predictions for one side if stale.

    The stored matrix and probabilities for ``side`` are compared against
    the current data index. If either is stale, both design matrices are
    rebuilt (they are always computed together) and cached in ``model``,
    and ``model[f"{side}_index"]`` is set to the current index.

    Predictions are recomputed when they are stale **or** when the stored
    index is not identical (labels and order) to the current index. The
    second condition matters because the single ``{side}_index`` entry
    labels both the matrix and the predictions: once it is rewritten to
    the current order, predictions kept in the old order would be paired
    with the wrong rows.

    Args:
        model: The IPW model dict; updated in place by replacing keys.
        side: Which side to validate and refresh.

    Raises:
        ValueError: Propagated from ``_require_target()``,
            ``_compute_ipw_matrices()`` or ``_ipw_class_index()``.
    """
    if side == "sample":
        current_idx = self._sf_sample.df.index
    else:
        self._require_target()
        current_idx = _assert_type(self._sf_target).df.index

    stored_idx = pd.Index(model.get(f"{side}_index", current_idx))
    matrix_stale = self._is_artifact_stale(
        model.get(f"model_matrix_{side}"), stored_idx, current_idx
    )
    prob_stale = self._is_artifact_stale(
        model.get(f"{side}_probability"), stored_idx, current_idx
    )
    if not (matrix_stale or prob_stale):
        return

    # Recompute matrices (always produces both sides) and cache both, since
    # the cost has already been paid.
    # NOTE(review): the other side's index is deliberately left untouched so
    # its own staleness check still runs; if that side was merely reordered,
    # its freshly cached matrix may not follow its stored index order.
    sample_matrix, target_matrix = self._compute_ipw_matrices(model)
    model["model_matrix_sample"] = sample_matrix
    model["model_matrix_target"] = target_matrix
    recomputed_matrix = sample_matrix if side == "sample" else target_matrix

    # The index below is rewritten to the current order, so predictions
    # stored in any other order must be regenerated alongside the matrix.
    if prob_stale or not stored_idx.equals(current_idx):
        from balance.weighting_methods.ipw import link_transform

        fit_model = _assert_type(model.get("fit"))
        class_index = self._ipw_class_index(fit_model)
        prob = np.asarray(
            fit_model.predict_proba(recomputed_matrix)[:, class_index]
        )
        model[f"{side}_probability"] = prob
        model[f"{side}_link"] = link_transform(prob)

    model[f"{side}_index"] = current_idx
def _validate_data_covariates(self, data: BalanceFrame) -> None:
    """Check that ``data`` exposes the same covariate column names as ``self``.

    Column names are compared as sets, so ordering is ignored. Target
    covariates are compared only when both objects have a target.

    Raises:
        ValueError: If the sample covariate names differ, or if both
            targets exist and their covariate names differ.
    """
    data_sample_cols = set(data._sf_sample.covars().df.columns)
    own_sample_cols = set(self._sf_sample.covars().df.columns)
    if data_sample_cols != own_sample_cols:
        raise ValueError(
            "data and self must have matching sample covariate column names."
        )

    if data._sf_target is None or self._sf_target is None:
        # Nothing to compare on the target side.
        return

    data_target_cols = set(_assert_type(data._sf_target).covars().df.columns)
    own_target_cols = set(_assert_type(self._sf_target).covars().df.columns)
    if data_target_cols != own_target_cols:
        raise ValueError(
            "data and self must have matching target covariate column names."
        )
# Typing-only overloads for design_matrix(): a single side ("sample" or
# "target") yields one DataFrame, while "both" (the default) yields a
# (sample, target) tuple of DataFrames.
@overload
def design_matrix(  # noqa: E704
    self,
    on: Literal["sample"],
    *,
    data: BalanceFrame | None = ...,
) -> pd.DataFrame: ...
@overload
def design_matrix(  # noqa: E704
    self,
    on: Literal["target"],
    *,
    data: BalanceFrame | None = ...,
) -> pd.DataFrame: ...
@overload
def design_matrix(  # noqa: E704
    self,
    on: Literal["both"] = ...,
    *,
    data: BalanceFrame | None = ...,
) -> tuple[pd.DataFrame, pd.DataFrame]: ...
[docs]
def design_matrix(
    self,
    on: Literal["sample", "target", "both"] = "both",
    *,
    data: BalanceFrame | None = None,
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
    """Return the design (model) matrices of the stored IPW model.

    These are the feature matrices produced by the stored preprocessing
    pipeline (formula expansion, one-hot encoding, NA indicators, scaling
    and penalty weighting).

    With ``data=None`` (default) the matrices cached on this object are
    returned, refreshed first if they are stale for the current rows. With
    ``data`` given, the stored preprocessing is applied to ``data``'s
    covariates and the result is returned without being cached.

    Args:
        on: ``"sample"`` for the respondent matrix, ``"target"`` for the
            target matrix, or ``"both"`` for ``(sample_matrix,
            target_matrix)``.
        data: Optional BalanceFrame supplying the covariates to transform.
            It does not have to be adjusted, but its covariate column
            names must match and it must have a target set.

    Returns:
        One model-matrix DataFrame, or a tuple of two when ``on="both"``.

    Raises:
        ValueError: If ``on`` is invalid, if this object has no usable IPW
            model, if a required target is missing, if the fit-time
            matrices were not stored, or if ``data`` has mismatched
            covariate columns.

    Notes:
        On the default path a stale cache (e.g. after
        ``set_fitted_model()``) is recomputed and written back into the
        model dict; that in-memory mutation is intentional. The ``data=``
        path never caches.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [0.0, 1.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [0.2, 0.8], "weight": [1.0, 1.0]}))
        >>> adjusted = BalanceFrame(sample=resp, target=tgt).fit(method="ipw")
        >>> x_s, x_t = adjusted.design_matrix(on="both")
        >>> x_s.shape[0], x_t.shape[0]
        (2, 2)
    """
    if on not in ("sample", "target", "both"):
        raise ValueError("on must be one of: 'sample', 'target', 'both'")
    model = self._require_ipw_model()
    columns: list[str] = _assert_type(model.get("X_matrix_columns"), list)
    wants_sample = on in ("sample", "both")
    wants_target = on in ("target", "both")

    # --- data= path: transform external covariates, never cache ---
    if data is not None:
        self._validate_data_covariates(data)
        if wants_target and data._sf_target is None:
            raise ValueError(
                "data must have a target set for on='target' or on='both'."
            )
        sample_matrix, target_matrix = self._compute_ipw_matrices(
            model, source=data
        )
        external_sample = self._matrix_to_dataframe(
            sample_matrix, data._sf_sample.df.index, columns
        )
        external_target: pd.DataFrame | None = None
        if data._sf_target is not None:
            external_target = self._matrix_to_dataframe(
                target_matrix,
                _assert_type(data._sf_target).df.index,
                columns,
            )
        if on == "sample":
            return external_sample
        if on == "target":
            return _assert_type(external_target)
        return (external_sample, _assert_type(external_target))

    # --- default path: serve (and refresh) the cached artifacts ---
    sample_df: pd.DataFrame | None = None
    target_df: pd.DataFrame | None = None

    if wants_sample:
        if model.get("model_matrix_sample") is None:
            raise ValueError(
                "IPW model is missing fit-time sample matrix. Call "
                "BalanceFrame.fit(method='ipw') or run ipw(..., "
                "store_fit_matrices=True) before using design_matrix(on='sample'/'both')."
            )
        self._ensure_fresh_ipw_artifacts(model, "sample")
        current_sample_idx = self._sf_sample.df.index
        stored_sample_idx = pd.Index(model.get("sample_index", current_sample_idx))
        labelled_sample = self._matrix_to_dataframe(
            model["model_matrix_sample"], stored_sample_idx, columns
        )
        sample_df = cast(
            pd.DataFrame,
            self._align_to_index(
                labelled_sample,
                current_sample_idx,
                caller="design_matrix()",
            ),
        )

    if wants_target:
        self._require_target()
        if model.get("model_matrix_target") is None:
            raise ValueError(
                "IPW model is missing fit-time target matrix. Call "
                "BalanceFrame.fit(method='ipw') or run ipw(..., "
                "store_fit_matrices=True) before using design_matrix(on='target'/'both')."
            )
        self._ensure_fresh_ipw_artifacts(model, "target")
        current_target_idx = _assert_type(self._sf_target).df.index
        stored_target_idx = pd.Index(model.get("target_index", current_target_idx))
        labelled_target = self._matrix_to_dataframe(
            model["model_matrix_target"], stored_target_idx, columns
        )
        target_df = cast(
            pd.DataFrame,
            self._align_to_index(
                labelled_target,
                current_target_idx,
                caller="design_matrix()",
            ),
        )

    if on == "sample":
        return _assert_type(sample_df)
    if on == "target":
        return _assert_type(target_df)
    return (_assert_type(sample_df), _assert_type(target_df))
# Typing-only overloads for predict_proba(): a single side ("sample" or
# "target") yields one Series, while "both" (the default) yields a
# (sample, target) tuple of Series.
@overload
def predict_proba(  # noqa: E704
    self,
    on: Literal["sample"],
    output: Literal["probability", "link"] = ...,
    *,
    data: BalanceFrame | None = ...,
) -> pd.Series: ...
@overload
def predict_proba(  # noqa: E704
    self,
    on: Literal["target"],
    output: Literal["probability", "link"] = ...,
    *,
    data: BalanceFrame | None = ...,
) -> pd.Series: ...
@overload
def predict_proba(  # noqa: E704
    self,
    on: Literal["both"] = ...,
    output: Literal["probability", "link"] = ...,
    *,
    data: BalanceFrame | None = ...,
) -> tuple[pd.Series, pd.Series]: ...
[docs]
def predict_proba(
    self,
    on: Literal["sample", "target", "both"] = "both",
    output: Literal["probability", "link"] = "probability",
    *,
    data: BalanceFrame | None = None,
) -> pd.Series | tuple[pd.Series, pd.Series]:
    """Return propensity scores from the stored IPW model.

    The scores are the estimator's class-1 probabilities (sample vs.
    target membership), optionally passed through ``link_transform``.

    With ``data=None`` (default) the predictions cached on this object are
    returned, refreshed first if they are stale for the current rows. With
    ``data`` given, the stored estimator scores ``data``'s covariates and
    nothing is cached.

    Args:
        on: Which population to score: ``"sample"``, ``"target"`` or
            ``"both"``.
        output: ``"probability"`` for class-1 probabilities, ``"link"``
            for their link-transformed values.
        data: Optional BalanceFrame to score with this object's model. Its
            covariate column names must match and it must have a target
            set.

    Returns:
        One Series, or a ``(sample, target)`` tuple when ``on="both"``.

    Raises:
        ValueError: If ``output`` or ``on`` is invalid, if this object has
            no usable IPW model, if a required target is missing, if the
            fit-time predictions were not stored, or if ``data`` has
            mismatched covariate columns.

    Notes:
        On the default path stale predictions may be recomputed and
        written back into the model dict. The ``data=`` path never caches.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [0.0, 1.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [0.2, 0.8], "weight": [1.0, 1.0]}))
        >>> adjusted = BalanceFrame(sample=resp, target=tgt).fit(method="ipw")
        >>> p = adjusted.predict_proba(on="target", output="probability")
        >>> int(p.shape[0])
        2
    """
    if output not in ("probability", "link"):
        raise ValueError("output must be one of: 'probability', 'link'")
    if on not in ("sample", "target", "both"):
        raise ValueError("on must be one of: 'sample', 'target', 'both'")
    model = self._require_ipw_model()
    wants_sample = on in ("sample", "both")
    wants_target = on in ("target", "both")

    # --- data= path: score external covariates, never cache ---
    if data is not None:
        from balance.weighting_methods.ipw import link_transform

        self._validate_data_covariates(data)
        if wants_target and data._sf_target is None:
            raise ValueError(
                "data must have a target set for on='target' or on='both'."
            )
        sample_matrix, target_matrix = self._compute_ipw_matrices(
            model, source=data
        )
        fit_model = _assert_type(model.get("fit"))
        class_index = self._ipw_class_index(fit_model)

        def _score(matrix: Any, frame: Any) -> pd.Series:
            # Score one side and label the values with that frame's index.
            prob = np.asarray(fit_model.predict_proba(matrix)[:, class_index])
            values_arr = link_transform(prob) if output == "link" else prob
            return pd.Series(values_arr, index=frame.df.index)

        external_sample: pd.Series | None = None
        external_target: pd.Series | None = None
        if wants_sample:
            external_sample = _score(sample_matrix, data._sf_sample)
        if wants_target:
            external_target = _score(target_matrix, _assert_type(data._sf_target))

        if on == "sample":
            return _assert_type(external_sample)
        if on == "target":
            return _assert_type(external_target)
        return (
            _assert_type(external_sample),
            _assert_type(external_target),
        )

    # --- default path: serve (and refresh) the cached predictions ---
    scored: dict[str, pd.Series] = {}
    for side in ("sample", "target"):
        if on not in (side, "both"):
            continue
        if side == "target":
            self._require_target()

        stored_key = (
            f"{side}_probability" if output == "probability" else f"{side}_link"
        )
        # Fit-time predictions must exist before a refresh is attempted.
        if not isinstance(model.get(stored_key), np.ndarray):
            raise ValueError(
                f"IPW model is missing fit-time {side} predictions for predict_proba(). "
                "Call BalanceFrame.fit(method='ipw') or run ipw(..., "
                f"store_fit_metadata=True) before using predict_proba(on='{side}'/'both')."
            )
        self._ensure_fresh_ipw_artifacts(model, side)  # type: ignore[arg-type]

        # Read again: the refresh may have replaced the cached values.
        refreshed_values = model.get(stored_key)
        if side == "sample":
            current_idx = self._sf_sample.df.index
        else:
            current_idx = _assert_type(self._sf_target).df.index
        stored_idx = pd.Index(model.get(f"{side}_index", current_idx))
        labelled = pd.Series(_assert_type(refreshed_values), index=stored_idx)
        scored[side] = cast(
            pd.Series,
            self._align_to_index(
                labelled,
                current_idx,
                caller="predict_proba()",
            ),
        )

    if on == "sample":
        return _assert_type(scored.get("sample"))
    if on == "target":
        return _assert_type(scored.get("target"))
    return (_assert_type(scored.get("sample")), _assert_type(scored.get("target")))
[docs]
def predict_weights(
self,
*,
data: BalanceFrame | None = None,
) -> pd.Series:
"""Predict responder weights from the fitted model's artifacts.
Reconstructs adjusted survey weights from stored fit-time artifacts
(propensity links, design weights, class balancing, trimming
parameters). On the fitted object itself, the result is numerically
equivalent to ``self.weights().df`` (within floating-point tolerance)
and serves as a validation that the stored artifacts are sufficient
to reproduce the adjustment.
When ``data`` is provided, computes weights for ``data``'s sample
using the stored model, without caching. This is the one-liner
alternative to the ``set_fitted_model`` workflow::
fitted.predict_weights(data=holdout_bf)
When ``data`` is None (default), uses this object's own data
(original behavior).
Dispatches by the adjustment method stored in the model dict:
- **IPW**: uses stored fit-time metadata (links, class balancing,
trimming, and design weights) to reproduce fitted responder weights.
- **Other methods**: not yet supported — will raise with guidance.
Args:
data: An optional BalanceFrame whose sample covariates are scored
using this object's stored model. Must have matching covariate
column names and a target set.
Returns:
A Series of predicted responder weights.
Raises:
ValueError: If no fitted model is available, if the method is
unsupported, if required target data is missing, or if
``data`` has mismatched covariate columns.
Examples:
>>> import pandas as pd
>>> from balance.sample_frame import SampleFrame
>>> from balance.balance_frame import BalanceFrame
>>> resp = SampleFrame.from_frame(
... pd.DataFrame({"id": [1, 2], "x": [0.0, 1.0], "weight": [1.0, 1.0]}))
>>> tgt = SampleFrame.from_frame(
... pd.DataFrame({"id": [3, 4], "x": [0.2, 0.8], "weight": [1.0, 1.0]}))
>>> adjusted = BalanceFrame(sample=resp, target=tgt).fit(method="ipw")
>>> w = adjusted.predict_weights()
>>> int(w.shape[0])
2
"""
# NOTE: The data= path intentionally duplicates the weight computation
# from set_fitted_model() rather than calling it. set_fitted_model()
# could be reused (scored = data.set_fitted_model(self, inplace=False);
# return scored.weight_series), but:
# 1. It does unnecessary work (deepcopy, _links, _adjustment_model)
# that would be immediately discarded.
# 2. design_matrix(data=...) and predict_proba(data=...) cannot use
# the same pattern — they need intermediate results (matrices,
# probabilities), not final weights. Keeping all three data= paths
# parallel is clearer.
if data is not None:
from balance.weighting_methods.ipw import link_transform, weights_from_link
self._validate_data_covariates(data)
model = self._require_fitted_model()
method = model.get("method")
if method != "ipw":
raise ValueError(
f"predict_weights(data=...) is not yet supported for method '{method}'. "
"Currently only 'ipw' is supported."
)
fit_obj = model.get("fit")
columns = model.get("X_matrix_columns")
if fit_obj is None or not isinstance(columns, list):
raise ValueError(
"IPW model metadata is missing fitted model information."
)
sample_matrix, _target_matrix = self._compute_ipw_matrices(
model, source=data
)
class_index = self._ipw_class_index(fit_obj)
prob = np.asarray(fit_obj.predict_proba(sample_matrix)[:, class_index])
link = link_transform(prob)
if data._sf_target is None:
raise ValueError("data must have a target set for predict_weights().")
sample_weights = data._sf_sample.df_weights.iloc[:, 0]
target_weights = _assert_type(data._sf_target).df_weights.iloc[:, 0]
training_target_weights = model.get("training_target_weights")
if isinstance(training_target_weights, pd.Series):
train_sum = training_target_weights.sum()
data_sum = target_weights.sum()
if train_sum > 0 and abs(train_sum - data_sum) / train_sum > 0.01:
logger.warning(
"predict_weights(data=...): data's target weights sum (%.2f) "
"differs from training target weights sum (%.2f). The "
"balance_classes correction and weight normalization will "
"use data's weights, which may produce different results "
"than the training fit.",
target_weights.sum(),
training_target_weights.sum(),
)
predicted = weights_from_link(
link=link,
balance_classes=bool(model.get("balance_classes", True)),
sample_weights=sample_weights,
target_weights=target_weights,
weight_trimming_mean_ratio=model.get("weight_trimming_mean_ratio"),
weight_trimming_percentile=model.get("weight_trimming_percentile"),
)
weight_name = getattr(data._sf_sample.weight_series, "name", None)
return pd.Series(predicted.values, index=data._sf_sample.df.index).rename(
weight_name
)
self._require_target()
model = self._require_fitted_model()
method = model.get("method")
if method == "ipw":
return self._predict_weights_ipw(model)
# TODO: Add predict_weights dispatch for other methods.
#
# Step 3 — Store fit artifacts in each weighting method:
# - cbps.py: Save standardization params (model_matrix_mean,
# model_matrix_std) and beta_optimal in the returned model dict.
# These are currently local variables discarded after fitting.
# Add a `store_fit_metadata: bool = False` parameter mirroring
# ipw.py. ~15 lines in cbps.py.
# - poststratify.py: Save the cell-ratio table
# (`combined["weight"]` at line 194), the variable list, and
# na_action in the returned model dict. Currently only
# `{"method": "poststratify"}` is returned. ~15 lines.
# - rake.py: Save the fitted contingency table (`m_fit`), the
# variable lists, and category-to-index mappings. Currently
# `m_fit` is discarded after per-row weight assignment.
# More complex due to N-dimensional array indexing. ~25 lines.
#
# Step 4 — Implement _predict_weights_* for each method:
# - _predict_weights_cbps(model): Rebuild the model matrix from
# current covariates, standardize with stored mean/std, compute
# logit_truncated(X @ beta_optimal), convert propensity to
# weights via compute_pseudo_weights_from_logit_probs. ~40 lines.
# - _predict_weights_poststratify(model): Join current sample rows
# on the stored cell-ratio table by cell variables, multiply
# ratio by design weight, normalize to target total. ~20 lines.
# - _predict_weights_rake(model): Look up each row's cell in the
# stored fitted N-dimensional contingency table, compute weight
# ratio (fitted_cell / original_cell), multiply by design
# weight. Same logic as rake.py lines 229-239. ~30 lines.
#
# Note: shared preprocessing is already extracted into
# build_design_matrix() in utils/model_matrix.py, used by both
# _compute_ipw_matrices() and ipw().
raise ValueError(
f"predict_weights() is not yet supported for method '{method}'. "
"Currently only 'ipw' is supported. Use adjust() to obtain "
"weights directly for other methods."
)
def _resolve_ipw_link(self, model: dict[str, Any]) -> np.ndarray:
"""Resolve sample link values from stored artifacts or recomputation."""
model_link = model.get("sample_link")
current_sample_idx = self._sf_sample.df.index
model_sample_idx = pd.Index(model.get("sample_index", current_sample_idx))
if (
isinstance(model_link, np.ndarray)
and model_link.shape[0] == len(current_sample_idx)
and model_sample_idx.equals(current_sample_idx)
):
return model_link
return self.predict_proba(on="sample", output="link").to_numpy()
def _resolve_design_weights(
    self,
    model: dict[str, Any],
    link: np.ndarray,
) -> tuple[pd.Series, pd.Series]:
    """Pick the sample and target design weights used by predict_weights().

    Stored training weights are preferred. Each side falls back,
    independently and with a logged warning, to the current design weights
    when the stored value is not usable:

    - ``training_sample_weights`` must be a ``pd.Series`` with as many
      entries as ``link`` and the same index as the current sample weights.
    - ``training_target_weights`` must be a ``pd.Series`` with the same
      length and index as the current target weights.

    Args:
        model: Fitted model dictionary holding the optional
            ``training_sample_weights`` / ``training_target_weights``.
        link: Sample link values; only its length is used here.

    Returns:
        tuple[pd.Series, pd.Series]: ``(sample_weights, target_weights)``.
    """
    fallback_sample = self._sf_sample.df_weights.iloc[:, 0]
    fallback_target = _assert_type(self._sf_target).df_weights.iloc[:, 0]

    sample_weights = model.get("training_sample_weights")
    sample_is_compatible = (
        isinstance(sample_weights, pd.Series)
        and len(sample_weights) == len(link)
        and sample_weights.index.equals(fallback_sample.index)
    )
    if not sample_is_compatible:
        logger.warning(
            "Falling back to current sample design weights in predict_weights(); "
            "stored training_sample_weights are unavailable or incompatible."
        )
        sample_weights = fallback_sample

    target_weights = model.get("training_target_weights")
    target_is_compatible = (
        isinstance(target_weights, pd.Series)
        and len(target_weights) == len(fallback_target)
        and target_weights.index.equals(fallback_target.index)
    )
    if not target_is_compatible:
        logger.warning(
            "Falling back to current target design weights in predict_weights(); "
            "stored training_target_weights are unavailable or incompatible."
        )
        target_weights = fallback_target

    return sample_weights, target_weights
def _predict_weights_ipw(self, model: dict[str, Any]) -> pd.Series:
    """Rebuild responder weights for an IPW fit from its stored artifacts.

    Args:
        model: Fitted IPW model dictionary. Must contain a non-None
            ``"fit"`` entry and a list under ``"X_matrix_columns"``.

    Returns:
        pd.Series: Predicted weights aligned to the current sample index
        and named after the current weight series.

    Raises:
        ValueError: If ``"fit"`` is missing or ``"X_matrix_columns"`` is
            not a list.
    """
    from balance.weighting_methods.ipw import weights_from_link

    has_fit = model.get("fit") is not None
    has_columns = isinstance(model.get("X_matrix_columns"), list)
    if not (has_fit and has_columns):
        raise ValueError("IPW model metadata is missing fitted model information.")

    link = self._resolve_ipw_link(model)
    sample_weights, target_weights = self._resolve_design_weights(model, link)
    predicted = weights_from_link(
        link=link,
        balance_classes=bool(model.get("balance_classes", True)),
        sample_weights=sample_weights,
        target_weights=target_weights,
        weight_trimming_mean_ratio=model.get("weight_trimming_mean_ratio"),
        weight_trimming_percentile=model.get("weight_trimming_percentile"),
    )

    # Label the result with the fit-time index only when it still matches
    # the current rows; otherwise use the current index.
    live_index = self._sf_sample.df.index
    fitted_index = pd.Index(model.get("sample_index", live_index))
    fitted_index_applies = len(fitted_index) == len(predicted) and fitted_index.equals(
        live_index
    )
    result_index = fitted_index if fitted_index_applies else live_index

    weight_name = getattr(_assert_type(self.weight_series), "name", None)
    aligned = self._align_to_index(
        pd.Series(predicted.values, index=result_index),
        self._sf_sample.df.index,
        caller="predict_weights()",
    )
    return cast(pd.Series, aligned).rename(weight_name)
@property
def model(self) -> dict[str, Any] | None:
    """Stored adjustment model dictionary (``None`` when nothing is stored).

    This is a read-only view of ``self._adjustment_model``; a freshly
    constructed, unadjusted BalanceFrame reports ``None``.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> bf.model is None
        True
    """
    stored_model = self._adjustment_model
    return stored_model
# --- Conversion ---
[docs]
@classmethod
def from_sample(cls, sample: Any) -> BalanceFrame:
    """Build a BalanceFrame from a :class:`~balance.sample_class.Sample`.

    The Sample must already carry a target (``Sample.set_target``). For an
    adjusted Sample the unadjusted responders and the model are copied
    over as well, so the adjustment state survives the conversion.

    Args:
        sample: A :class:`~balance.sample_class.Sample` instance with
            a target.

    Returns:
        BalanceFrame: A new BalanceFrame mirroring the Sample's data,
        target, and adjustment state.

    Raises:
        TypeError: If *sample* is not a Sample instance.
        ValueError: If *sample* does not have a target set.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_class import Sample
        >>> from balance.balance_frame import BalanceFrame
        >>> s = Sample.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> t = Sample.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame.from_sample(s.set_target(t))
        >>> bf.is_adjusted
        False
    """
    # Imported here because sample_class and balance_frame import each other.
    from balance.sample_class import Sample

    if not isinstance(sample, Sample):
        raise TypeError(
            f"'sample' must be a Sample instance, got {type(sample).__name__}"
        )
    if not sample.has_target():
        raise ValueError(
            "Sample must have a target set. "
            "Use sample.set_target(target) before calling BalanceFrame.from_sample()."
        )

    converted = cls._create(
        sample=SampleFrame.from_sample(sample),
        target=SampleFrame.from_sample(sample._links["target"]),
    )
    if not sample.is_adjusted():
        return converted

    # The pre-adjust frame must be a separate SampleFrame object so that
    # is_adjusted reports True on the result.
    converted._sf_sample_pre_adjust = SampleFrame.from_sample(
        sample._links["unadjusted"]
    )
    converted._adjustment_model = sample.model
    return converted
[docs]
def to_sample(self) -> Any:
    """Convert this BalanceFrame back to a :class:`~balance.sample_class.Sample`.

    A Sample is rebuilt from the responder frame and linked to a Sample
    rebuilt from the target frame. When this BalanceFrame is adjusted, the
    pre-adjustment responders are attached (as ``_sf_sample_pre_adjust``
    and as the ``"unadjusted"`` link) together with the adjustment model,
    so the returned Sample is adjusted too.

    Returns:
        Sample: A Sample mirroring this BalanceFrame's data, target,
        and adjustment state.

    Raises:
        ValueError: If this BalanceFrame has no target set.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2, 3], "x": [10.0, 20.0, 30.0],
        ...                   "weight": [1.0, 1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [4, 5, 6], "x": [15.0, 25.0, 35.0],
        ...                   "weight": [1.0, 1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> s = bf.to_sample()
        >>> s.has_target()
        True
    """
    # Imported here because sample_class and balance_frame import each other.
    from balance.sample_class import Sample

    target = self._sf_target
    if target is None:
        raise ValueError(
            "Cannot convert to Sample: BalanceFrame has no target set."
        )

    def _role_kwargs(sf: SampleFrame) -> dict[str, Any]:
        # Column-role keyword arguments shared by every from_frame() call below.
        return {
            "id_column": sf._id_column_name,
            "weight_column": sf._weight_column_name,
            "outcome_columns": sf.outcome_columns or None,
            "ignored_columns": sf.ignored_columns or None,
            "standardize_types": False,
        }

    responders = self._sf_sample
    resp_sample = Sample.from_frame(responders._df, **_role_kwargs(responders))
    target_sample = Sample.from_frame(target._df, **_role_kwargs(target))
    result = resp_sample.set_target(target_sample)

    pre_adjust = self._sf_sample_pre_adjust
    if self.is_adjusted and pre_adjust is not None:
        unadj_sf = SampleFrame.from_frame(pre_adjust._df, **_role_kwargs(pre_adjust))
        # pyre-ignore[16]: Sample gains this attr via BalanceFrame inheritance (diff 14.3)
        result._sf_sample_pre_adjust = unadj_sf
        # pyre-ignore[16]: Sample gains _links via BalanceFrame inheritance (diff 14.3)
        result._links["unadjusted"] = unadj_sf
        # pyre-ignore[16]: Sample gains this attr via BalanceFrame inheritance (diff 14.3)
        result._adjustment_model = self._adjustment_model
    return result
# --- BalanceDF integration ---
def _build_links_dict(self) -> dict[str, BalanceDFSource]:
"""Build a ``_links`` dict matching Sample._links structure.
Creates a dict mapping link names to SampleFrame instances for the
target and (if adjusted) the unadjusted responders so that
``BalanceDF._balancedf_child_from_linked_samples`` can walk the
links just as it does for the old ``Sample`` class.
Returns:
dict: Mapping of link names to BalanceDFSource instances.
"""
links: dict[str, BalanceDFSource] = {}
if self._sf_target is not None:
links["target"] = self._sf_target
if self.is_adjusted:
links["unadjusted"] = self._sf_sample_pre_adjust
return links
[docs]
def covars(self, formula: str | list[str] | None = None) -> Any:
    """Covariate view (:class:`~balance.balancedf_class.BalanceDFCovars`) of the responders.

    The view is created with the links from :meth:`_build_links_dict`
    (target, plus unadjusted when adjusted), so comparative methods such
    as ``.mean()`` and ``.asmd()`` can include those sources.

    Args:
        formula: Optional formula string (or list) for model matrix
            construction. Passed through to BalanceDFCovars.

    Returns:
        BalanceDFCovars: Covariate view with linked sources.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> bf.covars().df.columns.tolist()
        ['x']
    """
    from balance.balancedf_class import BalanceDFCovars, BalanceDFSource

    source = cast(BalanceDFSource, self)
    linked_sources = self._build_links_dict()
    return BalanceDFCovars(source, links=linked_sources, formula=formula)
[docs]
def weights(self) -> Any:
    """Weight view (:class:`~balance.balancedf_class.BalanceDFWeights`) of the responders.

    The view is created with the links from :meth:`_build_links_dict`
    (target, plus unadjusted when adjusted) for comparative weight
    analysis.

    Returns:
        BalanceDFWeights: Weight view with linked sources.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 2.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> bf.weights().df.columns.tolist()
        ['weight']
    """
    from balance.balancedf_class import BalanceDFSource, BalanceDFWeights

    # The BalanceFrame itself (not its _sf_sample) is handed over so that
    # r_indicator and other methods that reach for self._sample._links
    # find this BalanceFrame's _links.
    source = cast(BalanceDFSource, self)
    linked_sources = self._build_links_dict()
    return BalanceDFWeights(source, links=linked_sources)
[docs]
def outcomes(self) -> Any | None:
    """Outcome view (:class:`~balance.balancedf_class.BalanceDFOutcomes`), or ``None``.

    ``None`` is returned, before anything is imported or built, when the
    responder SampleFrame's ``outcome_columns`` is empty or ``None``.

    Returns:
        BalanceDFOutcomes or None: Outcome view with linked sources,
        or ``None`` if no outcomes are defined.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0],
        ...                   "y": [1.0, 0.0], "weight": [1.0, 1.0]}),
        ...     outcome_columns=["y"])
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> bf.outcomes().df.columns.tolist()
        ['y']
    """
    has_outcomes = bool(self._sf_sample.outcome_columns)
    if has_outcomes:
        from balance.balancedf_class import BalanceDFOutcomes, BalanceDFSource

        source = cast(BalanceDFSource, self)
        return BalanceDFOutcomes(source, links=self._build_links_dict())
    return None
# --- Summary & diagnostics ---
def _design_effect_diagnostics(
    self,
    n_rows: int | None = None,
) -> tuple[float | None, float | None, float | None]:
    """Design effect, effective sample size and its proportion for the responders.

    The design effect comes from ``weights_stats.design_effect`` applied
    to the first responder weight column. ESS is ``n_rows / de`` and ESSP
    is ``ESS / n_rows``; both stay ``None`` when ``n_rows`` is falsy or
    the design effect is zero.

    Args:
        n_rows: Optional row count to use for scaling. Defaults to the
            sample size when not provided.

    Returns:
        tuple: ``(design_effect, effective_sample_size,
        effective_sample_proportion)``. All ``None`` if the design
        effect raises ``TypeError`` / ``ValueError`` /
        ``ZeroDivisionError``, is ``None``, or is not finite.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [0, 1], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [0, 1], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> bf._design_effect_diagnostics()
        (1.0, 2.0, 1.0)
    """
    if n_rows is None:
        n_rows = len(self._sf_sample)

    try:
        de = weights_stats.design_effect(self._sf_sample.df_weights.iloc[:, 0])
    except (TypeError, ValueError, ZeroDivisionError) as exc:
        logger.debug("Unable to compute design effect: %s", exc)
        return None, None, None

    if de is None or not np.isfinite(de):
        return None, None, None

    # Without a usable row count or with a zero design effect, only the
    # design effect itself can be reported.
    if not n_rows or de == 0:
        return float(de), None, None

    ess = n_rows / de
    return float(de), ess, ess / n_rows
def _quick_adjustment_details(
self,
n_rows: int | None = None,
de: float | None = None,
ess: float | None = None,
essp: float | None = None,
) -> list[str]:
"""Collect quick-to-compute adjustment diagnostics for display.
Args:
de: Pre-computed design effect, or ``None`` to compute lazily.
ess: Pre-computed effective sample size, or ``None`` to compute
lazily.
essp: Pre-computed effective sample proportion, or ``None`` to
compute lazily.
Returns:
list[str]: Human-readable lines describing adjustment method,
trimming configuration, and weight diagnostics.
Examples:
>>> import pandas as pd
>>> from balance.sample_frame import SampleFrame
>>> from balance.balance_frame import BalanceFrame
>>> resp = SampleFrame.from_frame(
... pd.DataFrame({"id": [1, 2], "x": [0, 1], "weight": [1.0, 1.0]}))
>>> tgt = SampleFrame.from_frame(
... pd.DataFrame({"id": [3, 4], "x": [0, 1], "weight": [1.0, 1.0]}))
>>> bf = BalanceFrame(sample=resp, target=tgt)
>>> adjusted = bf.adjust(method="null")
>>> "method: null" in adjusted._quick_adjustment_details()
True
"""
details: list[str] = []
model = self.model
if isinstance(model, dict):
method = model.get("method")
if isinstance(method, str):
details.append(f"method: {method}")
trimming_mean_ratio = model.get("weight_trimming_mean_ratio")
if trimming_mean_ratio is not None:
details.append(f"weight trimming mean ratio: {trimming_mean_ratio}")
trimming_percentile = model.get("weight_trimming_percentile")
if trimming_percentile is not None:
details.append(f"weight trimming percentile: {trimming_percentile}")
if de is None:
de, ess, essp = self._design_effect_diagnostics(n_rows)
if de is not None:
details.append(f"design effect (Deff): {de:.3f}")
if essp is not None:
details.append(f"effective sample size proportion (ESSP): {essp:.3f}")
if ess is not None:
details.append(f"effective sample size (ESS): {ess:.1f}")
return details
[docs]
def summary(self) -> str:
    """Consolidated summary of covariate balance, weight health, and outcomes.

    Produces a multi-line summary combining covariate ASMD / KLD
    diagnostics, weight design effect, and outcome means. Delegates to
    :func:`~balance.summary_utils._build_summary` after computing the
    necessary intermediate values.

    When no target is set (and the object is not adjusted), returns a
    minimal summary with weight diagnostics and outcome means only.

    Returns:
        str: A human-readable multi-line summary string.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2, 3, 4], "x": [0, 1, 1, 0],
        ...                   "weight": [1.0, 2.0, 1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [5, 6, 7, 8], "x": [0, 0, 1, 1],
        ...                   "weight": [1.0, 1.0, 1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> adjusted = bf.adjust(method="null")
        >>> "Covariate diagnostics:" in adjusted.summary()
        True
    """
    if not self.has_target() and not self.is_adjusted:
        # No target: minimal summary (weight diagnostics + outcomes only)
        de, ess, essp = self._design_effect_diagnostics(self._df.shape[0])
        outcome_means = None
        if self._outcome_columns is not None:
            # outcomes() returns None when the responder SampleFrame has no
            # outcome columns, so check its result before dereferencing,
            # exactly as the with-target path below does.
            outcomes = self.outcomes()
            if outcomes is not None:
                outcome_means = outcomes.mean()
        return _build_summary(
            is_adjusted=False,
            has_target=False,
            covars_asmd=None,
            covars_kld=None,
            asmd_improvement_pct=None,
            quick_adjustment_details=[],
            design_effect=de,
            effective_sample_size=ess,
            effective_sample_proportion=essp,
            model_dict=self.model,
            outcome_means=outcome_means,
        )

    covars_asmd = self.covars().asmd()
    covars_kld = self.covars().kld(aggregate_by_main_covar=True)
    de, ess, essp = self._design_effect_diagnostics()

    # Improvement and the quick adjustment lines only exist after adjustment.
    asmd_improvement_pct = None
    quick_adjustment_details: list[str] = []
    if self.is_adjusted:
        asmd_improvement_pct = 100 * self.covars().asmd_improvement()
        quick_adjustment_details = self._quick_adjustment_details(
            de=de, ess=ess, essp=essp
        )

    outcome_means = None
    outcomes = self.outcomes()
    if outcomes is not None:
        outcome_means = outcomes.mean()

    return _build_summary(
        is_adjusted=bool(self.is_adjusted),
        has_target=True,
        covars_asmd=covars_asmd,
        covars_kld=covars_kld,
        asmd_improvement_pct=asmd_improvement_pct,
        quick_adjustment_details=quick_adjustment_details,
        design_effect=de,
        effective_sample_size=ess,
        effective_sample_proportion=essp,
        model_dict=self.model,
        outcome_means=outcome_means,
    )
[docs]
def diagnostics(
    self,
    weights_impact_on_outcome_method: str | None = "t_test",
    weights_impact_on_outcome_conf_level: float = 0.95,
) -> pd.DataFrame:
    """Table of diagnostics about the adjusted BalanceFrame.

    Produces a DataFrame with columns ``["metric", "val", "var"]``
    containing size information, weight diagnostics, model details,
    covariate ASMD, and optionally outcome-weight impact statistics.
    Delegates to :func:`~balance.summary_utils._build_diagnostics`.

    Args:
        weights_impact_on_outcome_method: Method for
            computing outcome-weight impact. Pass ``None`` to skip.
            Defaults to ``"t_test"``.
        weights_impact_on_outcome_conf_level: Confidence level
            for outcome impact intervals. Defaults to ``0.95``.

    Returns:
        pd.DataFrame: Diagnostics table with columns
        ``["metric", "val", "var"]``.

    Raises:
        ValueError: If this BalanceFrame has not been adjusted, or has no
            target set.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": ["1", "2"], "x": [0, 1],
        ...                   "weight": [1.0, 2.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": ["3", "4"], "x": [0, 1],
        ...                   "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> adjusted = bf.adjust(method="null")
        >>> adjusted.diagnostics().columns.tolist()
        ['metric', 'val', 'var']
    """
    logger.info("Starting computation of diagnostics of the fitting")
    self._require_adjusted()

    # Validate the target up front with a real exception: an ``assert``
    # disappears under ``python -O`` and would only fire after the
    # outcome-impact computation had already been paid for.
    target = self._sf_target
    if target is None:
        raise ValueError("diagnostics() requires a target")

    outcome_columns = self._sf_sample.df_outcomes
    outcome_impact = None
    if weights_impact_on_outcome_method is not None and outcome_columns is not None:
        # outcomes() returns None when the responder SampleFrame has no
        # outcome columns; skip the impact statistics in that case.
        outcomes_view = self.outcomes()
        if outcomes_view is not None:
            outcome_impact = outcomes_view.weights_impact_on_outcome_ss(
                method=weights_impact_on_outcome_method,
                conf_level=weights_impact_on_outcome_conf_level,
                round_ndigits=None,
            )

    result = _build_diagnostics(
        covars_df=self.covars().df,
        target_covars_df=target.df_covars,
        weights_summary=self.weights().summary(),
        model_dict=self.model,
        covars_asmd=self.covars().asmd(),
        covars_asmd_main=self.covars().asmd(aggregate_by_main_covar=True),
        outcome_columns=outcome_columns,
        weights_impact_on_outcome_method=weights_impact_on_outcome_method,
        weights_impact_on_outcome_conf_level=weights_impact_on_outcome_conf_level,
        outcome_impact=outcome_impact,
    )
    logger.info("Done computing diagnostics")
    return result
# --- Parity helpers ---
# --- DataFrame / export ---
@property
def df_all(self) -> pd.DataFrame:
    """All samples stacked into one DataFrame, tagged by a ``"source"`` column.

    Copies of the responder, target (when set) and unadjusted (when
    adjusted) DataFrames are concatenated vertically, in that order, with
    ``"source"`` set to ``"self"``, ``"target"`` and ``"unadjusted"``
    respectively. The row index of the result is renumbered.

    Returns:
        pd.DataFrame: A DataFrame with all rows from responder, target,
        and optionally unadjusted SampleFrames, plus a ``"source"``
        column.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> bf.df_all["source"].unique().tolist()
        ['self', 'target']
    """

    def _tagged(sample_frame: Any, label: str) -> pd.DataFrame:
        # Work on a copy so the stored frame never gains a "source" column.
        tagged = sample_frame._df.copy()
        tagged["source"] = label
        return tagged

    stacked: list[pd.DataFrame] = [_tagged(self._sf_sample, "self")]
    if self._sf_target is not None:
        stacked.append(_tagged(self._sf_target, "target"))
    if self.is_adjusted:
        stacked.append(_tagged(self._sf_sample_pre_adjust, "unadjusted"))
    return pd.concat(stacked, ignore_index=True)
@property
def df(self) -> pd.DataFrame:
    """Flat user-facing DataFrame from the responders.

    The pieces are concatenated column-wise in this order:
    id → covariates → outcomes → weight → ignored. Pieces that are
    ``None`` (for example, no outcomes) are passed to ``pd.concat`` as
    ``None``, which leaves them out of the result.

    Returns:
        pd.DataFrame: The responder's data in the order above.
    """
    covars_view = self.covars()
    outcomes_view = self.outcomes()
    ignored = self._sf_sample.df_ignored

    pieces = [
        self.id_series,
        None if covars_view is None else covars_view.df,
        None if outcomes_view is None else outcomes_view.df,
        None if self.weight_series is None else pd.DataFrame(self.weight_series),
        ignored,
    ]
    return pd.concat(pieces, axis=1)
[docs]
def keep_only_some_rows_columns(
    self,
    rows_to_keep: str | None = None,
    columns_to_keep: list[str] | None = None,
) -> BalanceFrame:
    """Return a BalanceFrame with filtered rows and/or columns.

    When at least one filter is given, a deep copy is made and the
    requested subset is applied to the copy's responder, target, and
    pre-adjustment SampleFrames (each distinct SampleFrame object is
    filtered exactly once). The original BalanceFrame is unchanged.
    When both filters are ``None`` there is nothing to do and ``self``
    is returned as-is (no copy).

    Args:
        rows_to_keep: A boolean expression string evaluated via
            ``pd.DataFrame.eval`` to select rows. Applied to each
            SampleFrame's underlying DataFrame. For example:
            ``'x > 10'`` or ``'gender == "Female"'``.
            Defaults to None (all rows kept).
        columns_to_keep: Covariate column names to retain. Special
            columns (id, weight) are always kept. Defaults to None
            (all columns kept).

    Returns:
        BalanceFrame: A new BalanceFrame with the filters applied, or
        ``self`` when no filter was requested.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2, 3], "x": [10.0, 20.0, 30.0],
        ...                   "weight": [1.0, 1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [4, 5, 6], "x": [15.0, 25.0, 35.0],
        ...                   "weight": [1.0, 1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> filtered = bf.keep_only_some_rows_columns(rows_to_keep="x > 15")
        >>> len(filtered._sf_sample._df)
        2
    """
    if rows_to_keep is None and columns_to_keep is None:
        return self

    new_bf = copy.deepcopy(self)
    with_target = bool(new_bf.has_target())

    # Without a target, warn when some requested columns do not exist.
    # The check runs before any filtering so it sees the full column set.
    if (
        not with_target
        and columns_to_keep is not None
        and not (set(columns_to_keep) <= set(new_bf.df.columns))
    ):
        logger.warning(
            "Note that not all columns_to_keep are in Sample. "
            "Only those that exist are removed"
        )

    new_bf._sf_sample = BalanceFrame._filter_sf(
        new_bf._sf_sample, rows_to_keep, columns_to_keep
    )
    if with_target and new_bf._sf_target is not None:
        new_bf._sf_target = BalanceFrame._filter_sf(
            new_bf._sf_target, rows_to_keep, columns_to_keep
        )

    # _filter_sf mutates its argument in place. If the pre-adjust frame is
    # the very same object as the responder frame it has already been
    # filtered above; filtering it again would apply the row expression
    # twice, which is wrong for expressions that are not idempotent
    # (e.g. "x > x.mean()"). This guard now applies with or without a target.
    pre_adjust = new_bf._sf_sample_pre_adjust
    if pre_adjust is not None and pre_adjust is not new_bf._sf_sample:
        new_bf._sf_sample_pre_adjust = BalanceFrame._filter_sf(
            pre_adjust, rows_to_keep, columns_to_keep
        )

    # Also filter linked BF/Sample objects in _links. A link that cannot be
    # filtered is left untouched and reported, never fatal.
    if new_bf._links:
        for k, v in list(new_bf._links.items()):
            if not isinstance(v, BalanceFrame):
                continue
            try:
                new_bf._links[k] = v.keep_only_some_rows_columns(
                    rows_to_keep=rows_to_keep,
                    columns_to_keep=columns_to_keep,
                )
            except (TypeError, ValueError, AttributeError, KeyError) as exc:
                logger.warning(
                    "couldn't filter _links['%s'] using provided filters: %s",
                    k,
                    exc,
                )
    return new_bf
@staticmethod
def _filter_sf(
sf: SampleFrame,
rows_to_keep: str | None,
columns_to_keep: list[str] | None,
) -> SampleFrame:
"""Apply row and column filtering to a SampleFrame in place.
Used internally by :meth:`keep_only_some_rows_columns` to filter
each SampleFrame (responders, target, unadjusted) consistently.
Args:
sf: The SampleFrame to filter (mutated in place).
rows_to_keep: A pandas ``eval()`` expression for row filtering,
or ``None`` to skip row filtering.
columns_to_keep: Column names to retain, or ``None`` to skip
column filtering. ID and weight columns are always retained.
Returns:
SampleFrame: The same *sf* instance, mutated.
Note:
If ``rows_to_keep`` references a column that does not exist in
*sf*, the ``UndefinedVariableError`` is caught, a warning is
logged, and row filtering is skipped for that SampleFrame. This
is intentional: linked frames (target, unadjusted) may not have
the same columns as the responder, so a filter expression valid
for the responder may fail on linked frames. This matches the
``Sample.keep_only_some_rows_columns`` behaviour.
"""
df = sf._df
if rows_to_keep is not None:
try:
mask = df.eval(rows_to_keep)
logger.info(f"(rows_filtered/total_rows) = ({mask.sum()}/{len(mask)})")
df = df[mask].reset_index(drop=True)
except pd.errors.UndefinedVariableError:
logger.warning(f"couldn't filter SampleFrame using {rows_to_keep}")
if columns_to_keep is not None:
keep_set = set(columns_to_keep)
keep_set.add(sf._id_column_name)
if sf._weight_column_name is not None:
keep_set.add(sf._weight_column_name)
for wc in sf.weight_columns_all:
keep_set.add(wc)
# Always preserve outcome columns (matching Sample behavior)
for oc in sf._column_roles.get("outcomes", []):
keep_set.add(oc)
df = df.loc[:, df.columns.isin(keep_set)]
new_covars = [c for c in sf._column_roles["covars"] if c in keep_set]
sf._column_roles = dict(sf._column_roles)
sf._column_roles["covars"] = new_covars
if sf._column_roles["outcomes"]:
sf._column_roles["outcomes"] = [
c for c in sf._column_roles["outcomes"] if c in keep_set
]
if sf._column_roles["predicted"]:
sf._column_roles["predicted"] = [
c for c in sf._column_roles["predicted"] if c in keep_set
]
if sf._column_roles["ignored"]:
sf._column_roles["ignored"] = [
c for c in sf._column_roles["ignored"] if c in keep_set
]
sf._df = df
return sf
[docs]
def to_csv(
    self, path_or_buf: FilePathOrBuffer | None = None, **kwargs: Any
) -> str | None:
    """Serialize the combined DataFrame as CSV.

    The frame exposed by :attr:`df` (responder + target + unadjusted rows,
    tagged by a ``"source"`` column) is handed to
    :func:`~balance.csv_utils.to_csv_with_defaults`, which performs the
    actual write.

    Args:
        path_or_buf: Where to write the CSV. When ``None``, the CSV text is
            returned instead of being written.
        **kwargs: Extra keyword arguments forwarded to
            :func:`pd.DataFrame.to_csv`.

    Returns:
        str or None: The CSV text when ``path_or_buf`` is ``None``,
        otherwise ``None``.

    Examples:
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> "id" in bf.to_csv()
        True
    """
    combined_frame = self.df
    return to_csv_with_defaults(combined_frame, path_or_buf, **kwargs)
[docs]
def to_download(self, tempdir: str | None = None) -> Any:
    """Expose the combined DataFrame as a downloadable CSV link.

    The frame from :attr:`df` is written to a temporary CSV file and an
    IPython :class:`~IPython.lib.display.FileLink` pointing at that file is
    returned, for interactive download.

    Args:
        tempdir: Directory in which to create the temporary file. When
            ``None``, :func:`tempfile.gettempdir` is used.

    Returns:
        FileLink: An IPython file link for downloading the CSV.

    Examples:
        >>> import tempfile
        >>> import pandas as pd
        >>> from balance.sample_frame import SampleFrame
        >>> from balance.balance_frame import BalanceFrame
        >>> resp = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [1, 2], "x": [10.0, 20.0], "weight": [1.0, 1.0]}))
        >>> tgt = SampleFrame.from_frame(
        ...     pd.DataFrame({"id": [3, 4], "x": [15.0, 25.0], "weight": [1.0, 1.0]}))
        >>> bf = BalanceFrame(sample=resp, target=tgt)
        >>> link = bf.to_download(tempdir=tempfile.gettempdir())
    """
    combined_frame = self.df
    return _to_download(combined_frame, tempdir)
# --- Methods moved from Sample ---
[docs]
def model_matrix(self) -> pd.DataFrame:
    """Build the model matrix for the responder covariates.

    Calls :func:`balance.util.model_matrix` on this object with
    ``add_na=True`` (NA indicators are added for null values) and returns
    the entry stored under the ``"sample"`` key.

    Returns:
        pd.DataFrame: The model matrix.
    """
    matrices = balance_util.model_matrix(self, add_na=True)
    # _assert_type narrows the "sample" entry to a DataFrame for the type
    # checker and for callers.
    return _assert_type(matrices["sample"], pd.DataFrame)
[docs]
def set_weights(
    self,
    weights: pd.Series | float | None,
    *,
    use_index: bool = False,
) -> None:
    """Replace the responder weights.

    This is a thin delegate to ``set_weights`` on the underlying responder
    SampleFrame (``_sf_sample``).

    On an unadjusted BalanceFrame (``is_adjusted`` is False),
    ``_sf_sample`` and ``_sf_sample_pre_adjust`` share the same DataFrame,
    so both see the new weights automatically. Changing base weights is
    not an adjustment.

    .. warning::
        On a BalanceFrame that has already been fitted (``adjust()`` was
        called), this changes the design weights but does **not**
        invalidate the stored fit artifacts (``_adjustment_model``). The
        link values in those artifacts were computed with the old weights,
        so ``predict_weights()`` would combine the new
        ``current_sample_weights`` with stale links, which is
        mathematically inconsistent. Call ``adjust()`` again after
        changing weights on an already-fitted BalanceFrame.

    Args:
        weights: The new weights: a Series, a scalar (broadcast to every
            row), or ``None`` (all weights become 1.0).
        use_index: When True, align *weights* by index rather than
            requiring a matching length. See
            :meth:`SampleFrame.set_weights`.
    """
    responder = self._sf_sample
    responder.set_weights(weights, use_index=use_index)
[docs]
def trim(
    self,
    ratio: float | int | None = None,
    percentile: float | tuple[float, float] | None = None,
    keep_sum_of_weights: bool = True,
    target_sum_weights: float | int | np.floating | None = None,
    *,
    inplace: bool = False,
) -> Self:
    """Clip extreme weights via mean-ratio clipping or percentile winsorization.

    The computation and the weight-history tracking are done by
    :meth:`SampleFrame.trim` on the responder frame. When not operating in
    place, the trimmed frame is wrapped in a fresh BalanceFrame that keeps
    the target, the pre-adjust baseline, the adjustment model and the
    links of this one.

    Args:
        ratio: Mean-ratio upper bound. Mutually exclusive with
            *percentile*.
        percentile: Percentile(s) for winsorization. Mutually exclusive
            with *ratio*.
        keep_sum_of_weights: Whether to rescale after trimming so the
            original sum of weights is preserved.
        target_sum_weights: If given, trimmed weights are rescaled so that
            their sum equals this value.
        inplace: If True, mutate this BalanceFrame's weights and return
            it. If False (default), return a new BalanceFrame.

    Returns:
        The BalanceFrame with trimmed weights (``self`` if *inplace*,
        otherwise a new instance).
    """
    # Both modes forward exactly the same trimming options; only the
    # ``inplace`` flag differs.
    trim_options: dict[str, Any] = {
        "ratio": ratio,
        "percentile": percentile,
        "keep_sum_of_weights": keep_sum_of_weights,
        "target_sum_weights": target_sum_weights,
    }

    if not inplace:
        trimmed_sf = self._sf_sample.trim(**trim_options, inplace=False)
        result = type(self)._create(sample=trimmed_sf, target=self._sf_target)
        result._sf_sample_pre_adjust = self._sf_sample_pre_adjust
        result._adjustment_model = self._adjustment_model
        # Carry over the existing links (target, unadjusted).
        result._links.update(self._links)
        return result

    self._sf_sample.trim(**trim_options, inplace=True)
    return self
[docs]
def set_unadjusted(self, second: BalanceFrame) -> Self:
    """Attach an unadjusted baseline for comparative analysis.

    A deep copy of this object is returned in which
    ``_sf_sample_pre_adjust`` points at *second*'s responder SampleFrame
    and ``_links["unadjusted"]`` points at *second* itself. This object is
    left untouched.

    Args:
        second: A BalanceFrame (or subclass) whose responder data becomes
            the unadjusted baseline.

    Returns:
        A new BalanceFrame with the unadjusted link set.

    Raises:
        TypeError: If *second* is not a BalanceFrame.
    """
    if isinstance(second, BalanceFrame):
        linked = deepcopy(self)
        linked._sf_sample_pre_adjust = second._sf_sample
        linked._links["unadjusted"] = second
        return linked

    raise TypeError(
        f"set_unadjusted must be called with a BalanceFrame argument, got {type(second).__name__}"
    )
# --- Column accessors (moved from Sample) ---
def _special_columns_names(self) -> list[str]:
    """List the names of every special column (id, weight, outcome, ignored).

    The result is ordered: id and weight names first (each only if its
    series is not ``None``), then the outcome column names (if outcome
    columns exist), then the contents of ``_ignored_column_names`` when
    that attribute is present.
    """
    id_and_weight: list[str] = []
    for series in (self.id_series, self.weight_series):
        if series is not None:
            id_and_weight.append(str(series.name))

    outcome_frame = self._outcome_columns
    outcome_names = []
    if outcome_frame is not None:
        outcome_names = outcome_frame.columns.tolist()

    # The attribute may be missing on this object, hence the getattr default.
    ignored_names = getattr(self, "_ignored_column_names", [])
    return id_and_weight + outcome_names + ignored_names
def _special_columns(self) -> pd.DataFrame:
    """Select the special columns from ``_df`` as a DataFrame.

    The columns are those named by :meth:`_special_columns_names`, in that
    order.
    """
    special_names = self._special_columns_names()
    return self._df[special_names]
def _covar_columns_names(self) -> list[str]:
    """Return names of all covariate columns.

    A covariate is any column of ``_df`` that is not listed by
    :meth:`_special_columns_names`. The order of ``_df``'s columns is
    preserved.

    Returns:
        list[str]: Covariate column names, in DataFrame column order.
    """
    # Compute the special names once: they do not depend on the column
    # being tested, so rebuilding the list per column was wasted work.
    # A frozenset also makes each membership test O(1).
    special_names = frozenset(self._special_columns_names())
    return [c for c in self._df.columns.values if c not in special_names]
def _covar_columns(self) -> pd.DataFrame:
    """Return the covariate columns as a DataFrame.

    Delegates to ``_covar_columns()`` on the responder SampleFrame
    (``_sf_sample``).
    """
    responder = self._sf_sample
    return responder._covar_columns()
# --- Error checks (moved from Sample) ---
def _require_adjusted(self) -> None:
    """Guard: raise ``ValueError`` unless ``is_adjusted`` is truthy."""
    if self.is_adjusted:
        return
    owner = type(self).__name__
    raise ValueError(
        f"This {owner} is not adjusted. Use .adjust() to adjust to target."
    )
def _require_target(self) -> None:
    """Guard: raise ``ValueError`` unless ``has_target()`` is truthy."""
    if self.has_target():
        return
    owner = type(self).__name__
    raise ValueError(
        f"This {owner} does not have a target set. Use .set_target() to add a target."
    )
def _require_outcomes(self) -> None:
    """Guard: raise ``ValueError`` when ``outcomes()`` returns ``None``."""
    if self.outcomes() is not None:
        return
    owner = type(self).__name__
    raise ValueError(f"This {owner} does not have outcome columns specified.")
def __repr__(self) -> str:
    """Return ``(<module>.<qualname>)`` on one line, then the ``__str__`` text."""
    klass = self.__class__
    header = f"({klass.__module__}.{klass.__qualname__})"
    return "\n".join((header, self.__str__()))
def __str__(self, pkg_source: str | None = None) -> str:
"""Return a readable summary of the sample and any applied adjustment.

The text is assembled in up to three parts:

1. A header describing this object: row count of ``_df``, number and
names of covariates, and the id / weight / outcome column names.
2. An ``adjustment details`` section, appended only when the object is
adjusted and ``_quick_adjustment_details`` returns at least one entry.
3. A ``target`` section (the target link's own ``__str__`` output plus
the variables shared with it), appended only when a target is set.

Args:
pkg_source: Package namespace used in the header. Defaults to
the module's ``__package__``.

Returns:
str: Multi-line description highlighting key structure and
adjustment details.
"""
if pkg_source is None:
pkg_source = __package__
# flag * label evaluates to the label when the flag is truthy and to ""
# otherwise; assumes the call returns a bool-like value — TODO confirm
# against _CallableBool.
is_adjusted = self.is_adjusted() * "Adjusted "
n_rows = self._df.shape[0]
n_variables = self._covar_columns().shape[1]
has_target = self.has_target() * " with target set"
# The method name is read from self.model["method"] when a model is stored;
# with no model the header simply omits the " using ..." suffix.
adjustment_method = (
" using " + _assert_type(self.model)["method"]
if self.model is not None
else ""
)
variables = ",".join(self._covar_columns_names())
# Missing id/weight series and missing outcome columns are rendered as the
# literal text "None".
id_column_name = self.id_series.name if self.id_series is not None else "None"
weight_column_name = (
self.weight_series.name if self.weight_series is not None else "None"
)
outcome_column_names = (
",".join(self._outcome_columns.columns.tolist())
if self._outcome_columns is not None
else "None"
)
desc = f"""
{is_adjusted}{pkg_source} Sample object{has_target}{adjustment_method}
{n_rows} observations x {n_variables} variables: {variables}
id_column: {id_column_name}, weight_column: {weight_column_name},
outcome_columns: {outcome_column_names}
"""
if self.is_adjusted():
# The helper's result is measured with len() and joined as strings below.
adjustment_details = self._quick_adjustment_details(n_rows)
if len(adjustment_details) > 0:
desc += """
adjustment details:
{details}
""".format(
details="\n ".join(adjustment_details)
)
if self.has_target():
# NOTE(review): presumably variables=None makes choose_variables return the
# variables common to both objects — confirm against balance.util.
common_variables = balance_util.choose_variables(
self, self._links["target"], variables=None
)
# Put a tab after every newline of the target's own summary so it nests
# under the "target:" heading.
target_str = self._links["target"].__str__().replace("\n", "\n\t")
n_common = len(common_variables)
# common_variables is rebound here from a collection to its comma-joined text.
common_variables = ",".join(common_variables)
desc += f"""
target:
{target_str}
{n_common} common variables: {common_variables}
"""
return desc