# Source code for balance.stats_and_plots.love_plot

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Love plot — visual covariate-imbalance diagnostic.

A "Love plot" (after Thomas Love) is the canonical visual for showing how
much each covariate's imbalance shrinks after applying weights. ``balance``
exposes a primitive (``love_plot``) operating on raw ``pd.Series`` inputs
and a method shortcut (``BalanceDFCovars.love_plot``) that pulls the chosen
metric off a fitted ``BalanceFrame``'s lineage.

The primitive supports static seaborn output, plotly output, and ASCII text
output. With both ``before`` and ``after`` it draws the canonical
before-vs-after view; with only ``before`` it draws a single-series
pre-adjust diagnostic.
"""

from __future__ import annotations

import logging
import math
import numbers
import re
from typing import Any, cast, get_args, Literal

import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Logger shared by every helper in this module; named after the package.
logger: logging.Logger = logging.getLogger(__package__)

# Accepted values for the ``library`` / ``order_by`` arguments. The runtime
# tuples are derived from the ``Literal`` aliases so argument validation and
# the type annotations cannot drift apart.
LovePlotLibrary = Literal["plotly", "seaborn", "balance"]
LovePlotOrderBy = Literal["before", "after", "diff", "alphabetical", "none"]
_LOVE_PLOT_LIBRARIES: tuple[str, ...] = cast(
    tuple[str, ...], get_args(LovePlotLibrary)
)
_LOVE_PLOT_ORDER_BY: tuple[str, ...] = cast(
    tuple[str, ...], get_args(LovePlotOrderBy)
)

# Colours: "Unweighted" markers, "Weighted" / single-series markers, and the
# dashed threshold reference line, respectively.
_BEFORE_COLOR: str = "#888888"
_AFTER_COLOR: str = "#0072B2"
_THRESHOLD_COLOR: str = "red"

# Index labels of the aggregate ``mean(<metric>)`` rows that are removed
# before plotting, so only per-covariate entries remain.
_SUMMARY_ROW_PATTERN: re.Pattern[str] = re.compile(r"^mean\((?:asmd|kld|emd|cvmd|ks)\)$")


def _drop_summary_rows(s: pd.Series) -> pd.Series:
    """Return ``s`` without its aggregate ``mean(<metric>)`` entries.

    Index labels are compared by their ``str()`` form against
    ``_SUMMARY_ROW_PATTERN``; every other entry is kept, in its original
    order.

    Args:
        s: Metric series indexed by covariate label.

    Returns:
        The entries of ``s`` whose label is not a summary-row label.
    """
    # Build the mask on the *original* index so it lines up with ``s`` even
    # when the labels are not strings.
    labels_as_text: pd.Series = s.index.astype(str).to_series(index=s.index)
    is_summary_row: pd.Series = labels_as_text.str.match(_SUMMARY_ROW_PATTERN)
    return s[~is_summary_row]


def _validate_metric_series(series: pd.Series, *, name: str) -> pd.Series:
    """Check one love-plot input and return it as a numeric series.

    Summary rows are removed first; the remaining entries must have unique
    labels and numeric values that are either finite or missing.

    Args:
        series: Candidate metric series indexed by covariate label.
        name: Argument name (``"before"`` / ``"after"``) used in error text.

    Returns:
        The summary-free series converted with ``pd.to_numeric``. Missing
        values are preserved; dropping them is left to the caller.

    Raises:
        TypeError: If ``series`` is not a ``pd.Series``.
        ValueError: If labels are duplicated, values are not numeric, or any
            value is infinite.
    """
    if not isinstance(series, pd.Series):
        raise TypeError(f"{name} must be a pandas Series; got {type(series)!r}.")

    covariates = _drop_summary_rows(series)
    labels = covariates.index
    if labels.has_duplicates:
        duplicated = labels[labels.duplicated()].astype(str).unique().tolist()
        raise ValueError(
            f"{name} contains duplicate covariate labels after summary-row drop: "
            f"{duplicated}. Love plots require one metric value per covariate."
        )

    try:
        values = pd.to_numeric(covariates, errors="raise")
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"{name} must contain numeric imbalance metric values."
        ) from exc

    # A value is rejected only when it is present *and* not finite (+/-inf);
    # missing entries are tolerated here.
    non_finite = values.notna() & ~np.isfinite(values.astype(float))
    if non_finite.any():
        bad_labels = values.index[non_finite].astype(str).tolist()
        raise ValueError(
            f"{name} contains non-finite imbalance values for covariates: "
            f"{bad_labels}."
        )
    return values


def _prepare_love_plot_data(
    before: pd.Series,
    after: pd.Series | None,
    *,
    order_by: LovePlotOrderBy,
) -> pd.DataFrame:
    """Validate the inputs and build the ordered frame the renderers consume.

    Args:
        before: Metric values before adjustment (or the only series).
        after: Metric values after adjustment, or ``None`` for single-series
            mode.
        order_by: Row ordering; must be one of ``_LOVE_PLOT_ORDER_BY``.

    Returns:
        A frame indexed by covariate with a single ``"value"`` column when
        ``after`` is ``None``, otherwise ``"Unweighted"`` and ``"Weighted"``
        columns. Rows containing NaN are removed. Row 0 is the row the
        renderers place at the bottom of the plot.

    Raises:
        ValueError: If ``order_by`` is invalid, the two series share no
            covariates, or nothing is left after dropping NaN entries. Errors
            from ``_validate_metric_series`` propagate unchanged.
    """
    if order_by not in _LOVE_PLOT_ORDER_BY:
        raise ValueError(
            f"order_by must be one of {_LOVE_PLOT_ORDER_BY}; got {order_by!r}."
        )

    pre: pd.Series = _validate_metric_series(before, name="before")
    if after is not None:
        post: pd.Series = _validate_metric_series(after, name="after")
        if not pre.index.equals(post.index):
            # Restrict both series to the labels they have in common.
            shared = pre.index.intersection(post.index)
            if len(shared) == 0:
                raise ValueError("before and after share no covariates.")
            logger.warning(
                "love_plot: aligning to %d common covariates (before=%d, after=%d).",
                len(shared),
                len(pre),
                len(post),
            )
            pre = pre.loc[shared]
            post = post.loc[shared]

        # A covariate is plotted only when both of its values are present.
        both_present = pre.notna() & post.notna()
        if not both_present.any():
            raise ValueError(
                "love_plot: no covariates to plot after dropping NaN entries "
                "(this usually means the categorical levels in the sample / "
                "target frames are completely disjoint)."
            )
        data = pd.DataFrame(
            {"Unweighted": pre[both_present], "Weighted": post[both_present]}
        )
    else:
        pre = pre.dropna()
        if pre.empty:
            raise ValueError(
                "love_plot: no covariates to plot — all entries are NaN "
                "(this usually means the categorical levels in the sample / "
                "target frames are completely disjoint)."
            )
        data = pd.DataFrame({"value": pre})

    if order_by == "none":
        return data
    if order_by == "alphabetical":
        # Sort on the string form so mixed-type labels (e.g. ints and strings)
        # stay comparable; reversed because row 0 is drawn at the bottom by
        # the seaborn/plotly renderers.
        return data.loc[sorted(data.index, key=str, reverse=True)]

    has_after = "Weighted" in data.columns
    if has_after and order_by == "after":
        sort_key = data["Weighted"].abs()
    elif has_after and order_by == "diff":
        # Signed change (after - before): negative means improvement, positive
        # means worsening. Ascending order therefore leaves the most-improved
        # covariate at the bottom and lets regressions rise to the top.
        sort_key = data["Weighted"] - data["Unweighted"]
    else:
        # "before", and also "after" / "diff" in single-series mode: order by
        # the magnitude of the first (or only) column.
        sort_key = data.iloc[:, 0].abs()
    return data.loc[sort_key.sort_values(ascending=True).index]


def _seaborn_love_plot(
    data: pd.DataFrame,
    *,
    xlabel: str,
    threshold: float | None,
    ax: matplotlib.axes.Axes | None,
    line: bool,
) -> matplotlib.axes.Axes:
    """Draw the love plot on a matplotlib ``Axes`` using seaborn markers.

    Args:
        data: Frame with either a single ``"value"`` column or
            ``"Unweighted"`` / ``"Weighted"`` columns; row ``i`` is drawn at
            ``y = i``.
        xlabel: X-axis label; also the legend label in single-series mode.
        threshold: X position of the dashed reference line, or ``None`` to
            omit it.
        ax: Axes to draw on. When ``None`` a new figure is created whose
            height grows with the number of rows.
        line: Whether to connect each before/after pair with a grey
            horizontal segment (two-series mode only).

    Returns:
        The axes that were drawn on.
    """
    n_rows = len(data)
    if ax is None:
        _, ax = plt.subplots(figsize=(6, max(3, 0.3 * n_rows)))

    positions = np.arange(n_rows)
    single_series = list(data.columns) == ["value"]
    if single_series:
        sns.scatterplot(
            x=data["value"].values,
            y=positions,
            marker="o",
            color=_AFTER_COLOR,
            label=xlabel,
            ax=ax,
        )
    else:
        unweighted = data["Unweighted"].values
        weighted = data["Weighted"].values
        if line:
            # Connectors sit underneath both marker layers (lowest zorder).
            ax.hlines(
                y=positions,
                xmin=unweighted,
                xmax=weighted,
                colors="#BBBBBB",
                linewidth=1,
                zorder=1,
            )
        # Draw "Unweighted" first, then "Weighted" on top of it.
        marker_layers = (
            (unweighted, "o", "Unweighted", _BEFORE_COLOR, 2),
            (weighted, "s", "Weighted", _AFTER_COLOR, 3),
        )
        for xs, marker, label, color, zorder in marker_layers:
            sns.scatterplot(
                x=xs,
                y=positions,
                marker=marker,
                label=label,
                color=color,
                ax=ax,
                zorder=zorder,
            )

    ax.set_yticks(positions)
    ax.set_yticklabels(data.index)
    if threshold is not None:
        ax.axvline(threshold, linestyle="--", color=_THRESHOLD_COLOR, alpha=0.5)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Covariate")
    ax.legend(loc="best")
    ax.grid(axis="x", alpha=0.3)
    return ax


def _safe_plotly_show(fig: Any) -> None:
    """Call ``fig.show()`` without failing on an nbformat mime-renderer error.

    A ``ValueError`` raised by ``fig.show()`` is passed to
    ``weighted_comparisons_plots._is_nbformat_mime_error``. When that helper
    recognises it, the error is logged as a warning and swallowed; any other
    ``ValueError`` is re-raised. Per the original author's note this occurs
    in CLI/test environments where Plotly's mime renderer cannot import
    ``nbformat``.

    Args:
        fig: Object exposing a ``show()`` method (a Plotly figure).

    Returns:
        ``None``. The figure itself is left untouched for the caller to
        return.
    """
    # Imported lazily, at call time, rather than at module import.
    from balance.stats_and_plots.weighted_comparisons_plots import (
        _is_nbformat_mime_error,
    )

    try:
        fig.show()
    except ValueError as exc:
        if _is_nbformat_mime_error(exc):
            logger.warning(
                "Plotly notebook mime rendering unavailable; returning figure "
                "without displaying it. Original error: %s",
                exc,
            )
            return
        raise


def _plotly_love_plot(
    data: pd.DataFrame,
    *,
    xlabel: str,
    threshold: float | None,
    line: bool,
    show: bool,
    **layout_kwargs: Any,
) -> Any:
    """Build the love plot as a ``plotly.graph_objects.Figure``.

    Args:
        data: Frame with either a single ``"value"`` column or
            ``"Unweighted"`` / ``"Weighted"`` columns; row ``i`` is drawn at
            ``y = i``.
        xlabel: X-axis title; also the trace name in single-series mode.
        threshold: X position of the dashed reference line, or ``None`` to
            omit it.
        line: Whether to add grey connector segments between each
            before/after pair (two-series mode only).
        show: Whether to display the figure via ``_safe_plotly_show``.
        **layout_kwargs: Forwarded to a second ``fig.update_layout`` call, so
            they override the defaults set here.

    Returns:
        The constructed figure.
    """
    import plotly.graph_objects as go

    fig = go.Figure()
    n_rows = len(data)
    # Numeric y positions with explicit tick text are used instead of a
    # categorical axis: labels that stringify identically (e.g. ``1`` and
    # ``"1"``) would otherwise collapse onto one categorical row in Plotly,
    # overlapping markers and connector lines.
    rows = list(range(n_rows))
    tick_text = [str(label) for label in data.index]

    def _marker_trace(column: str, symbol: str, color: str, name: str) -> Any:
        # One marker-only scatter trace for a single column of ``data``.
        return go.Scatter(
            x=data[column].tolist(),
            y=rows,
            mode="markers",
            marker={"symbol": symbol, "color": color},
            name=name,
        )

    if list(data.columns) == ["value"]:
        fig.add_trace(_marker_trace("value", "circle", _AFTER_COLOR, xlabel))
    else:
        if line:
            # All connectors live in one trace; a ``None`` entry after each
            # pair breaks the line so segments are not joined to each other.
            seg_x: list[float | None] = []
            seg_y: list[float | None] = []
            for row, start, end in zip(rows, data["Unweighted"], data["Weighted"]):
                seg_x += [float(start), float(end), None]
                seg_y += [row, row, None]
            fig.add_trace(
                go.Scatter(
                    x=seg_x,
                    y=seg_y,
                    mode="lines",
                    line={"color": "#BBBBBB", "width": 1},
                    hoverinfo="skip",
                    showlegend=False,
                    name="Change",
                )
            )
        fig.add_trace(
            _marker_trace("Unweighted", "circle", _BEFORE_COLOR, "Unweighted")
        )
        fig.add_trace(_marker_trace("Weighted", "square", _AFTER_COLOR, "Weighted"))

    if threshold is not None:
        # ``yref="paper"`` makes the line span the full plot height.
        fig.add_shape(
            type="line",
            x0=threshold,
            x1=threshold,
            xref="x",
            y0=0,
            y1=1,
            yref="paper",
            line={"color": _THRESHOLD_COLOR, "dash": "dash"},
            opacity=0.5,
        )

    fig.update_layout(
        xaxis_title=xlabel,
        yaxis_title="Covariate",
        yaxis={"tickmode": "array", "tickvals": rows, "ticktext": tick_text},
        height=max(300, 30 * n_rows),
        template="plotly_white",
    )
    if layout_kwargs:
        fig.update_layout(**layout_kwargs)
    if show:
        _safe_plotly_show(fig)
    return fig


def _ascii_bar(value: float, max_value: float, *, width: int, char: str) -> str:
    """Render ``|value|`` as a run of ``char`` scaled against ``max_value``.

    Args:
        value: Metric value; only its magnitude is used.
        max_value: Magnitude that corresponds to a full-width bar.
        width: Number of characters in a full-width bar.
        char: String repeated to form the bar.

    Returns:
        ``char`` repeated ``round(|value| / max_value * width)`` times, or an
        empty string when ``max_value`` is not positive. A non-zero ``value``
        always yields at least one ``char``.
    """
    if max_value <= 0:
        return ""
    scaled = (abs(value) / max_value) * width
    length = int(round(scaled))
    # Keep tiny but non-zero values visible instead of rounding them away.
    if length == 0 and value != 0:
        length = 1
    return char * length


def _ascii_love_plot(
    data: pd.DataFrame,
    *,
    xlabel: str,
    threshold: float | None,
    bar_width: int,
) -> str:
    """Render the love plot as a plain-text table with ASCII bars.

    Rows are printed in reverse frame order, so the last row of ``data``
    appears first. Bars are scaled against the largest absolute value in
    ``data``. ``threshold`` is only reported in the header.

    Args:
        data: Frame with either a single ``"value"`` column or
            ``"Unweighted"`` / ``"Weighted"`` columns.
        xlabel: Metric name shown in the title and, in single-series mode, as
            the value-column header.
        threshold: Reference value echoed in the header (``"none"`` when
            ``None``).
        bar_width: Character width of a full-length bar.

    Returns:
        The table as one newline-joined string.
    """
    longest_label = max(map(len, map(str, data.index)))
    # Label column: at least as wide as its header, capped at 40 characters
    # (longer labels are truncated by the format spec below).
    covar_width = min(40, max(longest_label, len("Covariate")))
    peak = float(data.abs().max().max())
    threshold_text = f"{threshold:.3g}" if threshold is not None else "none"

    out = [f"Love plot ({xlabel})", f"Threshold: {threshold_text}"]
    top_down = data.iloc[::-1]

    if list(data.columns) == ["value"]:
        out.append(f"{'Covariate':<{covar_width}} | {xlabel:>10} | Plot")
        out.append("-" * (covar_width + bar_width + 18))
        for covar, raw in top_down["value"].items():
            value = float(raw)
            bar = _ascii_bar(value, peak, width=bar_width, char="#")
            out.append(
                f"{str(covar):<{covar_width}.{covar_width}} | {value:>10.4g} | {bar}"
            )
        return "\n".join(out)

    out.append(
        f"{'Covariate':<{covar_width}} | {'Unweighted':>10} | {'Weighted':>10} | Change"
    )
    out.append("-" * (covar_width + bar_width + 40))
    paired_rows = zip(top_down.index, top_down["Unweighted"], top_down["Weighted"])
    for covar, raw_before, raw_after in paired_rows:
        before_value = float(raw_before)
        after_value = float(raw_after)
        # "." bars show the unweighted magnitude, "#" bars the weighted one.
        before_bar = _ascii_bar(before_value, peak, width=bar_width, char=".")
        after_bar = _ascii_bar(after_value, peak, width=bar_width, char="#")
        # An unchanged magnitude is reported as "improved".
        direction = "improved" if abs(after_value) <= abs(before_value) else "worse"
        out.append(
            f"{str(covar):<{covar_width}.{covar_width}} | "
            f"{before_value:>10.4g} | {after_value:>10.4g} | "
            f"{before_bar} -> {after_bar} ({direction})"
        )
    return "\n".join(out)


def love_plot(
    before: pd.Series,
    after: pd.Series | None = None,
    *,
    xlabel: str = "ASMD",
    threshold: float | None = 0.1,
    ax: matplotlib.axes.Axes | None = None,
    library: LovePlotLibrary = "plotly",
    line: bool = False,
    order_by: LovePlotOrderBy = "diff",
    show: bool = False,
    bar_width: int = 30,
    **layout_kwargs: Any,
) -> Any:
    """Plot per-covariate imbalance, before vs. after weighting.

    Args:
        before: Per-covariate metric values before adjustment, or the only
            series to plot when ``after`` is ``None``.
        after: Per-covariate metric values after adjustment.
        xlabel: Metric label for the x-axis and ASCII header.
        threshold: Optional non-negative vertical/reference threshold. Pass
            ``None`` to skip it.
        ax: Optional matplotlib ``Axes`` for ``library="seaborn"``.
        library: One of ``"plotly"`` (default; interactive
            ``plotly.graph_objects.Figure``), ``"seaborn"`` (static
            seaborn/matplotlib axes), or ``"balance"`` (ASCII string).
        line: If ``True`` and both series are supplied, connect each
            before/after pair with a horizontal line. Has no effect on the
            ASCII output.
        order_by: Covariate sorting. ``"diff"`` (default) orders by signed
            ``after - before`` so the most-improved covariates are at the
            bottom and the most-worsened are at the top; ``"before"`` /
            ``"after"`` order by absolute pre / post values;
            ``"alphabetical"`` sorts by covariate name; ``"none"`` keeps
            input order. When ``after`` is ``None``, ``"after"`` and
            ``"diff"`` fall back to ordering by the absolute ``before``
            values.
        show: For ``library="plotly"``, whether to call ``fig.show()``.
        bar_width: Width of ASCII bars for ``library="balance"``.
        **layout_kwargs: Additional Plotly layout options when
            ``library="plotly"``. For any other library they are ignored and
            a warning is logged.

    Returns:
        ``matplotlib.axes.Axes`` for static output, ``plotly`` ``Figure`` for
        plotly output, or ``str`` for ASCII output.

    Raises:
        TypeError: If ``threshold`` is not a real number (or is a bool), if
            ``line`` / ``show`` are not bools, or if ``before`` / ``after``
            is not a ``pd.Series``.
        ValueError: If ``threshold`` is negative or non-finite, ``library``
            or ``order_by`` is unknown, ``bar_width`` is not a positive
            integer, ``ax`` is given with a non-seaborn library, or the
            series cannot be validated / leave no covariates to plot.

    Examples:
        ::

            >>> import pandas as pd
            >>> from balance.stats_and_plots.love_plot import love_plot
            >>> before = pd.Series({"age": 0.42, "income": 0.31})
            >>> after = pd.Series({"age": 0.05, "income": 0.08})
            >>> fig = love_plot(before, after)  # doctest: +SKIP
            >>> ax = love_plot(  # doctest: +SKIP
            ...     before, after, library="seaborn", line=True
            ... )
            >>> text = love_plot(  # doctest: +SKIP
            ...     before, after, library="balance"
            ... )
    """
    # --- argument validation (cheap checks before touching the data) ---
    if threshold is not None:
        # bool is a subclass of int, hence of numbers.Real; reject it explicitly.
        if not isinstance(threshold, numbers.Real) or isinstance(threshold, bool):
            raise TypeError(
                f"threshold must be a non-negative finite number or None; got {type(threshold)!r}."
            )
        if not math.isfinite(float(threshold)) or threshold < 0:
            raise ValueError("threshold must be non-negative and finite, or None.")
    if library not in _LOVE_PLOT_LIBRARIES:
        raise ValueError(
            f"library must be one of {_LOVE_PLOT_LIBRARIES}; got {library!r}."
        )
    if not isinstance(line, bool):
        raise TypeError(f"line must be a bool; got {type(line)!r}.")
    if not isinstance(show, bool):
        raise TypeError(f"show must be a bool; got {type(show)!r}.")
    if not isinstance(bar_width, int) or isinstance(bar_width, bool) or bar_width <= 0:
        raise ValueError("bar_width must be a positive integer.")
    if ax is not None and library != "seaborn":
        raise ValueError("ax can only be used with library='seaborn'.")

    data = _prepare_love_plot_data(before, after, order_by=order_by)

    if library == "plotly":
        return _plotly_love_plot(
            data,
            xlabel=xlabel,
            threshold=threshold,
            line=line,
            show=show,
            **layout_kwargs,
        )

    # Layout options only make sense for Plotly. Tell the caller they were
    # dropped for *every* other backend, not just seaborn, so a typo in
    # ``library`` choice does not silently discard styling.
    if layout_kwargs:
        logger.warning(
            "Ignoring plotly layout kwargs for library=%r: %s",
            library,
            sorted(layout_kwargs.keys()),
        )
    if library == "seaborn":
        return _seaborn_love_plot(
            data, xlabel=xlabel, threshold=threshold, ax=ax, line=line
        )
    return _ascii_love_plot(
        data, xlabel=xlabel, threshold=threshold, bar_width=bar_width
    )