Comparing results of fitting CBPS between R's CBPS package and Python's balance package (using simulated data)¶
This notebook shows how we can reproduce (almost exactly) the weights produced by R's CBPS package, using the implementation in balance.
The example is based on simulated data that was provided in the help page of the CBPS function (i.e.: ?CBPS::CBPS, you can see it here).
The R code used to create the data is available here.
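The exact data-generating process lives in the linked R script; as a rough orientation, a Kang–Schafer-style simulation in Python might look like the sketch below. The selection and outcome coefficients here are illustrative assumptions, not the script's actual values.

```python
import numpy as np

# Illustrative sketch only -- the exact data-generating process is in the
# R script linked above; all coefficients below are assumptions.
rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 4))  # four continuous covariates

# Hypothetical logistic selection model deciding who ends up in the sample
logit = -X[:, 0] + 0.5 * X[:, 1] - 0.25 * X[:, 2] - 0.1 * X[:, 3]
in_sample = rng.binomial(1, 1 / (1 + np.exp(-logit)))

# Hypothetical linear outcome model with mean near 210
y = 210 + 27.4 * X[:, 0] + 13.7 * (X[:, 1] + X[:, 2] + X[:, 3]) + rng.normal(size=n)
```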
Loading data and fitting CBPS using balance¶
In [1]:
import balance
import numpy as np
import pandas as pd
import session_info
from balance import Sample
INFO (2026-02-21 04:52:54,141) [__init__/<module> (line 72)]: Using balance version 0.16.1
balance (Version 0.16.1) loaded:
📖 Documentation: https://import-balance.org/
🛠️ Help / Issues: https://github.com/facebookresearch/balance/issues/
📄 Citation:
Sarig, T., Galili, T., & Eilat, R. (2023).
balance - a Python package for balancing biased data samples.
https://arxiv.org/abs/2307.06024
Tip: You can view this message anytime with balance.help()
In [2]:
target_df, sample_df = balance.datasets.load_data("sim_data_cbps")
# print(target_df.head())
print(target_df.info())
<class 'pandas.DataFrame'>
Index: 254 entries, 1 to 498
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   X1            254 non-null    float64
 1   X2            254 non-null    float64
 2   X3            254 non-null    float64
 3   X4            254 non-null    float64
 4   cbps_weights  254 non-null    float64
 5   y             254 non-null    float64
 6   id            254 non-null    int64
dtypes: float64(6), int64(1)
memory usage: 15.9 KB
None
In [3]:
sample = Sample.from_frame(sample_df, outcome_columns = ['y', 'cbps_weights'])
target = Sample.from_frame(target_df, outcome_columns = ['y', 'cbps_weights'])
sample_target = sample.set_target(target)
WARNING (2026-02-21 04:52:54,367) [input_validation/guess_id_column (line 337)]: Guessed id column name id for the data
WARNING (2026-02-21 04:52:54,368) [sample_class/from_frame (line 469)]: Casting id column to string
WARNING (2026-02-21 04:52:54,379) [pandas_utils/_warn_of_df_dtypes_change (line 514)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences -
WARNING (2026-02-21 04:52:54,380) [pandas_utils/_warn_of_df_dtypes_change (line 525)]: The (old) dtypes that changed for df (before the change):
WARNING (2026-02-21 04:52:54,381) [pandas_utils/_warn_of_df_dtypes_change (line 528)]: id int64 dtype: object
WARNING (2026-02-21 04:52:54,381) [pandas_utils/_warn_of_df_dtypes_change (line 529)]: The (new) dtypes saved in df (after the change):
WARNING (2026-02-21 04:52:54,383) [pandas_utils/_warn_of_df_dtypes_change (line 530)]: id str dtype: object
WARNING (2026-02-21 04:52:54,383) [sample_class/from_frame (line 549)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2026-02-21 04:52:54,394) [input_validation/guess_id_column (line 337)]: Guessed id column name id for the data
WARNING (2026-02-21 04:52:54,395) [sample_class/from_frame (line 469)]: Casting id column to string
WARNING (2026-02-21 04:52:54,405) [pandas_utils/_warn_of_df_dtypes_change (line 514)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences -
WARNING (2026-02-21 04:52:54,405) [pandas_utils/_warn_of_df_dtypes_change (line 525)]: The (old) dtypes that changed for df (before the change):
WARNING (2026-02-21 04:52:54,407) [pandas_utils/_warn_of_df_dtypes_change (line 528)]: id int64 dtype: object
WARNING (2026-02-21 04:52:54,407) [pandas_utils/_warn_of_df_dtypes_change (line 529)]: The (new) dtypes saved in df (after the change):
WARNING (2026-02-21 04:52:54,409) [pandas_utils/_warn_of_df_dtypes_change (line 530)]: id str dtype: object
WARNING (2026-02-21 04:52:54,409) [sample_class/from_frame (line 549)]: No weights passed. Adding a 'weight' column and setting all values to 1
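The "Casting id column to string" warnings above reflect an ordinary pandas dtype conversion applied to the guessed id column, along the lines of:

```python
import pandas as pd

# An integer id column, as in target_df / sample_df
ids = pd.Series([1, 2, 3], name="id")

# Cast to string -- the same kind of change balance warns about
ids_str = ids.astype(str)
assert ids_str.tolist() == ["1", "2", "3"]
```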
In [4]:
# The defaults of adjust would not yield similar-enough results to R's CBPS,
# so we override two parameters (no transformations and no weight trimming):
# adjust = sample_target.adjust(method = "cbps")  # default call, for reference
adjust = sample_target.adjust(method = "cbps", transformations = None, weight_trimming_mean_ratio = None)
INFO (2026-02-21 04:52:54,418) [cbps/cbps (line 537)]: Starting cbps function
INFO (2026-02-21 04:52:54,431) [cbps/cbps (line 588)]: The formula used to build the model matrix: ['X4 + X3 + X2 + X1']
INFO (2026-02-21 04:52:54,432) [cbps/cbps (line 599)]: The number of columns in the model matrix: 4
INFO (2026-02-21 04:52:54,433) [cbps/cbps (line 600)]: The number of rows in the model matrix: 500
INFO (2026-02-21 04:52:54,435) [cbps/cbps (line 669)]: Finding initial estimator for GMM optimization
INFO (2026-02-21 04:52:54,443) [cbps/cbps (line 696)]: Finding initial estimator for GMM optimization that minimizes the balance loss
INFO (2026-02-21 04:52:54,936) [cbps/cbps (line 732)]: Running GMM optimization
INFO (2026-02-21 04:52:55,414) [cbps/cbps (line 859)]: Done cbps function
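Before comparing the two weight vectors, note that only relative weights matter for weighted estimates: rescaling every weight by the same constant leaves the weighted mean unchanged, so the two implementations need not agree on absolute scale. A quick sanity check with toy numbers (not the notebook's data):

```python
import numpy as np

y = np.array([200.0, 210.0, 240.0])
w = np.array([0.5, 0.3, 0.2])

# Multiplying all weights by a constant does not change the weighted mean
assert np.isclose(np.average(y, weights=w), np.average(y, weights=10 * w))
```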
Comparing results of balance and CBPS¶
In [5]:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.scatter(adjust.df["cbps_weights"], adjust.df["weight"], color="blue")
ax.set_xlabel("cbps_weights")
ax.set_ylabel("weight")
ax.set_title("CBPS weights vs. weights")
adjust.df[["cbps_weights", "weight"]].corr(method = "pearson")
Out[5]:
| | cbps_weights | weight |
|---|---|---|
| cbps_weights | 1.000000 | 0.980463 |
| weight | 0.980463 | 1.000000 |
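The matrix above comes from pandas' `.corr`; as a cross-check on what that number means, Pearson's r can be computed directly from its definition on toy data:

```python
import numpy as np
import pandas as pd

x = pd.Series([1.0, 2.0, 3.0, 4.0])
y = pd.Series([1.1, 1.9, 3.2, 3.8])

# Pearson's r from the definition: covariance over the product of spreads
xc, yc = x - x.mean(), y - y.mean()
r = (xc * yc).sum() / np.sqrt((xc**2).sum() * (yc**2).sum())

# Matches pandas' built-in Pearson correlation
assert np.isclose(r, x.corr(y))
```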
In [6]:
log_adjust = adjust.df.loc[(adjust.df["cbps_weights"] > 0) & (adjust.df["weight"] > 0), ["cbps_weights", "weight"]].copy()
log_adjust["log_cbps_weights"] = np.log10(log_adjust["cbps_weights"])
log_adjust["log_weight"] = np.log10(log_adjust["weight"])
fig, ax = plt.subplots()
ax.scatter(log_adjust["log_cbps_weights"], log_adjust["log_weight"], color="blue")
ax.set_xlabel("log10(cbps_weights)")
ax.set_ylabel("log10(weight)")
ax.set_title("Log10 CBPS weights vs. weights")
log_adjust[["log_cbps_weights", "log_weight"]].corr(method = "pearson")
Out[6]:
| | log_cbps_weights | log_weight |
|---|---|---|
| log_cbps_weights | 1.00000 | 0.99494 |
| log_weight | 0.99494 | 1.00000 |
In [7]:
# Notice how the mean of the y outcome moves from 220.68 (unadjusted) to 207.56
# after weighting, similar to R's CBPS correction of 220.68 -> 206.84
print(adjust.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source self target unadjusted self_ci target_ci unadjusted_ci
cbps_weights 0.008 0.004 0.004 (0.006, 0.009) (0.004, 0.004) (0.004, 0.005)
y 207.559 199.544 220.677 (203.353, 211.765) (195.473, 203.616) (216.335, 225.019)
Weights impact on outcomes (t_test):
mean_yw0 mean_yw1 mean_diff diff_ci_lower diff_ci_upper t_stat p_value n
outcome
cbps_weights 0.004 0.008 0.004 0.001 0.006 3.101 0.002 246.0
y 220.677 207.559 -13.118 -39.115 12.879 -0.994 0.321 246.0
Response rates (relative to number of respondents in sample):
y cbps_weights
n 246.0 246.0
% 100.0 100.0
Response rates (relative to notnull rows in the target):
y cbps_weights
n 246.000000 246.000000
% 96.850394 96.850394
Response rates (in the target):
y cbps_weights
n 254.0 254.0
% 100.0 100.0
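The "self" column in the summary above is a weighted mean of the outcome. With toy numbers (hypothetical values, not the notebook's data), the same quantity can be computed directly from a DataFrame holding an outcome and a weight column:

```python
import numpy as np
import pandas as pd

# Toy stand-in for adjust.df, with an outcome column and fitted weights
df = pd.DataFrame({"y": [200.0, 210.0, 240.0], "weight": [0.5, 0.3, 0.2]})

unweighted = df["y"].mean()                           # ~ the 'unadjusted' column
weighted = np.average(df["y"], weights=df["weight"])  # ~ the 'self' column
```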
In [8]:
# Just to get some sense of what the weights did to the covars:
adjust.covars().plot(library = "seaborn", dist_type = "kde")
In [9]:
# In contrast, if we were to use the original CBPS weights, we'd get this:
from copy import deepcopy
adjust2 = deepcopy(adjust)
cbps_weights = adjust2.outcomes().df.cbps_weights
adjust2.set_weights(cbps_weights)
# adjust2.covars().plot(library = "seaborn", dist_type = "kde")  # plotted in a later cell
In [10]:
# We can see that this worked, since the weighted average of y is now 206.84
print(adjust2.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source self target unadjusted self_ci target_ci unadjusted_ci
cbps_weights 0.007 0.004 0.004 (0.006, 0.009) (0.004, 0.004) (0.004, 0.005)
y 206.844 199.544 220.677 (202.453, 211.235) (195.473, 203.616) (216.335, 225.019)
Weights impact on outcomes (t_test):
mean_yw0 mean_yw1 mean_diff diff_ci_lower diff_ci_upper t_stat p_value n
outcome
cbps_weights 0.004 0.007 0.003 0.002 0.005 3.690 0.000 246.0
y 220.677 206.844 -13.833 -38.112 10.447 -1.122 0.263 246.0
Response rates (relative to number of respondents in sample):
y cbps_weights
n 246.0 246.0
% 100.0 100.0
Response rates (relative to notnull rows in the target):
y cbps_weights
n 246.000000 246.000000
% 96.850394 96.850394
Response rates (in the target):
y cbps_weights
n 254.0 254.0
% 100.0 100.0
In [11]:
# And here is what the covars look like under the original CBPS weights from R:
# an almost identical correction to the one balance produced
adjust2.covars().plot(library = "seaborn", dist_type = "kde")
Session info¶
In [12]:
session_info.show(html=False, dependencies=True)
-----
balance                     0.16.1
matplotlib                  3.10.8
numpy                       2.4.2
pandas                      3.0.1
session_info                v1.0.1
-----
PIL                         12.1.1
anyio                       NA
arrow                       1.4.0
asttokens                   NA
attr                        25.4.0
attrs                       25.4.0
babel                       2.18.0
certifi                     2026.01.04
charset_normalizer          3.4.4
comm                        0.2.3
cycler                      0.12.1
cython_runtime              NA
dateutil                    2.9.0.post0
debugpy                     1.8.20
decorator                   5.2.1
defusedxml                  0.7.1
executing                   2.2.1
fastjsonschema              NA
fqdn                        NA
idna                        3.11
ipykernel                   7.2.0
isoduration                 NA
jedi                        0.19.2
jinja2                      3.1.6
joblib                      1.5.3
json5                       0.13.0
jsonpointer                 3.0.0
jsonschema                  4.26.0
jsonschema_specifications   NA
jupyter_events              0.12.0
jupyter_server              2.17.0
jupyterlab_server           2.28.0
kiwisolver                  1.4.9
lark                        1.3.1
markupsafe                  3.0.3
matplotlib_inline           0.2.1
mpl_toolkits                NA
narwhals                    2.16.0
nbformat                    5.10.4
packaging                   26.0
parso                       0.8.6
patsy                       1.0.2
platformdirs                4.9.2
plotly                      6.5.2
prometheus_client           NA
prompt_toolkit              3.0.52
psutil                      7.2.2
pure_eval                   0.2.3
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.2.3
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pygments                    2.19.2
pyparsing                   3.3.2
pythonjsonlogger            NA
referencing                 NA
requests                    2.32.5
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rfc3987_syntax              NA
rpds                        NA
scipy                       1.17.0
seaborn                     0.13.2
send2trash                  NA
six                         1.17.0
sklearn                     1.8.0
sphinxcontrib               NA
stack_data                  0.6.3
statsmodels                 0.14.6
threadpoolctl               3.6.0
tornado                     6.5.4
traitlets                   5.14.3
typing_extensions           NA
uri_template                NA
urllib3                     2.6.3
wcwidth                     0.6.0
webcolors                   NA
websocket                   1.9.0
yaml                        6.0.3
zmq                         27.1.0
zoneinfo                    NA
-----
IPython                     9.10.0
jupyter_client              8.8.0
jupyter_core                5.9.1
jupyterlab                  4.5.4
notebook                    7.5.3
-----
Python 3.12.12 (main, Oct 10 2025, 01:01:16) [GCC 13.3.0]
Linux-6.11.0-1018-azure-x86_64-with-glibc2.39
-----
Session information updated at 2026-02-21 04:53