Comparing results of fitting CBPS between R's CBPS package and Python's balance package (using simulated data)

This notebook shows how we can reproduce (almost exactly) the weights produced from R's CBPS package, using the implementation in balance.

The example is based on simulated data that was provided in the help page of the CBPS function (i.e., ?CBPS::CBPS; you can see it here).

The R code used to create the data is available here.

Loading data and fitting CBPS using balance

In [1]:
import balance
import numpy as np
import pandas as pd
import session_info

from balance import Sample
INFO (2026-02-21 04:52:54,141) [__init__/<module> (line 72)]: Using balance version 0.16.1
balance (Version 0.16.1) loaded:
    📖 Documentation: https://import-balance.org/
    🛠️ Help / Issues: https://github.com/facebookresearch/balance/issues/
    📄 Citation:
        Sarig, T., Galili, T., & Eilat, R. (2023).
        balance - a Python package for balancing biased data samples.
        https://arxiv.org/abs/2307.06024

    Tip: You can view this message anytime with balance.help()

In [2]:
target_df, sample_df = balance.datasets.load_data("sim_data_cbps")
# print(target_df.head())
print(target_df.info())
<class 'pandas.DataFrame'>
Index: 254 entries, 1 to 498
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   X1            254 non-null    float64
 1   X2            254 non-null    float64
 2   X3            254 non-null    float64
 3   X4            254 non-null    float64
 4   cbps_weights  254 non-null    float64
 5   y             254 non-null    float64
 6   id            254 non-null    int64  
dtypes: float64(6), int64(1)
memory usage: 15.9 KB
None
In [3]:
sample = Sample.from_frame(sample_df, outcome_columns = ['y', 'cbps_weights'])
target = Sample.from_frame(target_df, outcome_columns = ['y', 'cbps_weights'])
sample_target = sample.set_target(target)
WARNING (2026-02-21 04:52:54,367) [input_validation/guess_id_column (line 337)]: Guessed id column name id for the data
WARNING (2026-02-21 04:52:54,368) [sample_class/from_frame (line 469)]: Casting id column to string
WARNING (2026-02-21 04:52:54,379) [pandas_utils/_warn_of_df_dtypes_change (line 514)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences - 
WARNING (2026-02-21 04:52:54,380) [pandas_utils/_warn_of_df_dtypes_change (line 525)]: The (old) dtypes that changed for df (before the change):
WARNING (2026-02-21 04:52:54,381) [pandas_utils/_warn_of_df_dtypes_change (line 528)]: 
id    int64
dtype: object
WARNING (2026-02-21 04:52:54,381) [pandas_utils/_warn_of_df_dtypes_change (line 529)]: The (new) dtypes saved in df (after the change):
WARNING (2026-02-21 04:52:54,383) [pandas_utils/_warn_of_df_dtypes_change (line 530)]: 
id    str
dtype: object
WARNING (2026-02-21 04:52:54,383) [sample_class/from_frame (line 549)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2026-02-21 04:52:54,394) [input_validation/guess_id_column (line 337)]: Guessed id column name id for the data
WARNING (2026-02-21 04:52:54,395) [sample_class/from_frame (line 469)]: Casting id column to string
WARNING (2026-02-21 04:52:54,405) [pandas_utils/_warn_of_df_dtypes_change (line 514)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences - 
WARNING (2026-02-21 04:52:54,405) [pandas_utils/_warn_of_df_dtypes_change (line 525)]: The (old) dtypes that changed for df (before the change):
WARNING (2026-02-21 04:52:54,407) [pandas_utils/_warn_of_df_dtypes_change (line 528)]: 
id    int64
dtype: object
WARNING (2026-02-21 04:52:54,407) [pandas_utils/_warn_of_df_dtypes_change (line 529)]: The (new) dtypes saved in df (after the change):
WARNING (2026-02-21 04:52:54,409) [pandas_utils/_warn_of_df_dtypes_change (line 530)]: 
id    str
dtype: object
WARNING (2026-02-21 04:52:54,409) [sample_class/from_frame (line 549)]: No weights passed. Adding a 'weight' column and setting all values to 1
In [4]:
# The defaults, i.e. sample_target.adjust(method = "cbps"), would not yield similar-enough results,
# so we turn off the covariate transformations and the weight trimming:
adjust = sample_target.adjust(method = "cbps", transformations = None, weight_trimming_mean_ratio = None)
INFO (2026-02-21 04:52:54,418) [cbps/cbps (line 537)]: Starting cbps function
INFO (2026-02-21 04:52:54,431) [cbps/cbps (line 588)]: The formula used to build the model matrix: ['X4 + X3 + X2 + X1']
INFO (2026-02-21 04:52:54,432) [cbps/cbps (line 599)]: The number of columns in the model matrix: 4
INFO (2026-02-21 04:52:54,433) [cbps/cbps (line 600)]: The number of rows in the model matrix: 500
INFO (2026-02-21 04:52:54,435) [cbps/cbps (line 669)]: Finding initial estimator for GMM optimization
INFO (2026-02-21 04:52:54,443) [cbps/cbps (line 696)]: Finding initial estimator for GMM optimization that minimizes the balance loss
INFO (2026-02-21 04:52:54,936) [cbps/cbps (line 732)]: Running GMM optimization
INFO (2026-02-21 04:52:55,414) [cbps/cbps (line 859)]: Done cbps function
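
The call above turns off weight trimming by passing weight_trimming_mean_ratio = None. To illustrate why trimming would pull the result away from R's untrimmed weights, here is a minimal sketch that assumes trimming means capping each weight at a multiple of the mean weight. The helper cap_at_mean_ratio is hypothetical and is not balance's implementation, which may differ in details such as rescaling the weights after the cap.

```python
import numpy as np


def cap_at_mean_ratio(w, ratio):
    """Hypothetical helper: cap every weight at ratio * mean(w)."""
    w = np.asarray(w, dtype=float)
    return np.minimum(w, ratio * w.mean())


# Toy weights (not taken from the notebook): one extreme value among small ones.
w = np.array([0.5, 1.0, 1.5, 17.0])  # mean is 5.0
capped = cap_at_mean_ratio(w, ratio=2.0)  # cap is 2.0 * 5.0 = 10.0
print(capped)
```

Only the extreme weight changes here, but that is exactly the kind of observation on which a trimmed and an untrimmed fit would disagree.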

Comparing results of balance and CBPS

In [5]:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.scatter(adjust.df["cbps_weights"], adjust.df["weight"], color="blue")
ax.set_xlabel("cbps_weights")
ax.set_ylabel("weight")
ax.set_title("CBPS weights vs. weights")
adjust.df[["cbps_weights", "weight"]].corr(method = "pearson")
Out[5]:
              cbps_weights    weight
cbps_weights      1.000000  0.980463
weight            0.980463  1.000000
[Figure: scatter plot "CBPS weights vs. weights" (x-axis: cbps_weights, y-axis: weight)]
In [6]:
log_adjust = adjust.df.loc[(adjust.df["cbps_weights"] > 0) & (adjust.df["weight"] > 0), ["cbps_weights", "weight"]].copy()
log_adjust["log_cbps_weights"] = np.log10(log_adjust["cbps_weights"])
log_adjust["log_weight"] = np.log10(log_adjust["weight"])
fig, ax = plt.subplots()
ax.scatter(log_adjust["log_cbps_weights"], log_adjust["log_weight"], color="blue")
ax.set_xlabel("log10(cbps_weights)")
ax.set_ylabel("log10(weight)")
ax.set_title("Log10 CBPS weights vs. weights")
log_adjust[["log_cbps_weights", "log_weight"]].corr(method = "pearson")
Out[6]:
                  log_cbps_weights  log_weight
log_cbps_weights           1.00000     0.99494
log_weight                 0.99494     1.00000
[Figure: scatter plot "Log10 CBPS weights vs. weights" (x-axis: log10(cbps_weights), y-axis: log10(weight))]
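
Pearson correlation ignores scale, so a high correlation alone does not show how close the two weight vectors are in magnitude. One possible complement, sketched below on toy data, is to rescale both vectors to sum to 1 and look at the relative differences. The function compare_normalized is a hypothetical helper, and the toy weights are simulated here rather than taken from the notebook.

```python
import numpy as np


def compare_normalized(w_a, w_b):
    """Rescale both weight vectors to sum to 1, then summarize how much they differ."""
    a = np.asarray(w_a, dtype=float)
    b = np.asarray(w_b, dtype=float)
    a = a / a.sum()
    b = b / b.sum()
    rel_diff = np.abs(a - b) / b  # assumes strictly positive weights
    return {
        "pearson": float(np.corrcoef(a, b)[0, 1]),
        "median_rel_diff": float(np.median(rel_diff)),
        "max_rel_diff": float(rel_diff.max()),
    }


# Toy data: w_b is w_a on a different scale, with about 1% multiplicative noise.
rng = np.random.default_rng(0)
w_a = rng.lognormal(mean=0.0, sigma=1.0, size=246)
w_b = 250.0 * w_a * (1.0 + rng.normal(scale=0.01, size=246))
result = compare_normalized(w_a, w_b)
print(result)
```

In this notebook the same helper could be applied to adjust.df["weight"] and adjust.df["cbps_weights"], after dropping any non-positive weights as the log10 cell does.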
In [7]:
# Notice that the mean of y moves from 220.68 (unadjusted) to 207.56 (with balance's weights), similar to R's 220.67 -> 206.8
print(adjust.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.008    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             207.559  199.544     220.677  (203.353, 211.765)  (195.473, 203.616)  (216.335, 225.019)

Weights impact on outcomes (t_test):
              mean_yw0  mean_yw1  mean_diff  diff_ci_lower  diff_ci_upper  t_stat  p_value      n
outcome                                                                                          
cbps_weights     0.004     0.008      0.004          0.001          0.006   3.101    0.002  246.0
y              220.677   207.559    -13.118        -39.115         12.879  -0.994    0.321  246.0

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
             y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
        y  cbps_weights
n  254.0         254.0
%  100.0         100.0
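
The unadjusted and adjusted means of y in the summary can be read as an unweighted and a weighted average. The sketch below assumes the usual definition, sum(w * y) / sum(w), and uses toy numbers rather than the notebook's data; weighted_mean is a hypothetical helper, not a balance function. It also shows that rescaling the weights leaves the weighted mean unchanged, which is why weights on different scales can be compared through the outcome mean.

```python
import numpy as np


def weighted_mean(y, w):
    """Weighted average of y with weights w: sum(w * y) / sum(w)."""
    y = np.asarray(y, dtype=float)
    w = np.asarray(w, dtype=float)
    return float(np.sum(w * y) / np.sum(w))


y = np.array([10.0, 20.0, 30.0])
w = np.array([1.0, 1.0, 2.0])

unweighted = float(y.mean())            # (10 + 20 + 30) / 3
weighted = weighted_mean(y, w)          # (10 + 20 + 60) / 4
rescaled = weighted_mean(y, 100.0 * w)  # same weights on a different scale

print(unweighted, weighted, rescaled)
```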

In [8]:
# Just to get some sense of what the weights did to the covars:
adjust.covars().plot(library = "seaborn", dist_type = "kde")
[Figure: KDE plots of the covariate distributions under balance's CBPS weights]
In [9]:
# In contrast, if we were to use the original CBPS weights, we'd get this:
from copy import deepcopy
adjust2 = deepcopy(adjust)
cbps_weights = adjust2.outcomes().df.cbps_weights
adjust2.set_weights(cbps_weights)
# .covars().plot(library = "seaborn", dist_type = "kde")
In [10]:
# We can see that this worked, since the weighted average of y is now 206.8 (the value R reports)
print(adjust2.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.007    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             206.844  199.544     220.677  (202.453, 211.235)  (195.473, 203.616)  (216.335, 225.019)

Weights impact on outcomes (t_test):
              mean_yw0  mean_yw1  mean_diff  diff_ci_lower  diff_ci_upper  t_stat  p_value      n
outcome                                                                                          
cbps_weights     0.004     0.007      0.003          0.002          0.005   3.690    0.000  246.0
y              220.677   206.844    -13.833        -38.112         10.447  -1.122    0.263  246.0

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
             y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
        y  cbps_weights
n  254.0         254.0
%  100.0         100.0

In [11]:
# And here is what the covars look like under the weights from the original CBPS implementation in R:
# the correction is almost identical to the one balance made
adjust2.covars().plot(library = "seaborn", dist_type = "kde")
[Figure: KDE plots of the covariate distributions under the original CBPS weights from R]
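
The visual comparison of the two sets of KDE plots can be backed by a number per covariate. A common choice is the absolute standardized mean difference (ASMD) between the weighted sample and the target. The sketch below is a toy illustration that assumes the difference is standardized by the target's standard deviation; the asmd helper is hypothetical and the data are made up.

```python
import numpy as np


def asmd(x_sample, w_sample, x_target):
    """|weighted sample mean - target mean| / target standard deviation."""
    x_sample = np.asarray(x_sample, dtype=float)
    w = np.asarray(w_sample, dtype=float)
    x_target = np.asarray(x_target, dtype=float)
    sample_mean = np.sum(w * x_sample) / np.sum(w)
    return float(abs(sample_mean - x_target.mean()) / x_target.std(ddof=1))


# Toy covariate: the sample over-represents large values relative to the target.
x_target = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
x_sample = np.array([2.0, 3.0, 4.0])

before = asmd(x_sample, np.ones(3), x_target)               # equal weights
after = asmd(x_sample, np.array([3.0, 2.0, 1.0]), x_target)  # up-weight small values

print(before, after)
```

Computing this for X1 to X4 under both weight columns would give one way to check the "almost identical" claim numerically.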

Session info

In [12]:
session_info.show(html=False, dependencies=True)
-----
balance             0.16.1
matplotlib          3.10.8
numpy               2.4.2
pandas              3.0.1
session_info        v1.0.1
-----
PIL                         12.1.1
anyio                       NA
arrow                       1.4.0
asttokens                   NA
attr                        25.4.0
attrs                       25.4.0
babel                       2.18.0
certifi                     2026.01.04
charset_normalizer          3.4.4
comm                        0.2.3
cycler                      0.12.1
cython_runtime              NA
dateutil                    2.9.0.post0
debugpy                     1.8.20
decorator                   5.2.1
defusedxml                  0.7.1
executing                   2.2.1
fastjsonschema              NA
fqdn                        NA
idna                        3.11
ipykernel                   7.2.0
isoduration                 NA
jedi                        0.19.2
jinja2                      3.1.6
joblib                      1.5.3
json5                       0.13.0
jsonpointer                 3.0.0
jsonschema                  4.26.0
jsonschema_specifications   NA
jupyter_events              0.12.0
jupyter_server              2.17.0
jupyterlab_server           2.28.0
kiwisolver                  1.4.9
lark                        1.3.1
markupsafe                  3.0.3
matplotlib_inline           0.2.1
mpl_toolkits                NA
narwhals                    2.16.0
nbformat                    5.10.4
packaging                   26.0
parso                       0.8.6
patsy                       1.0.2
platformdirs                4.9.2
plotly                      6.5.2
prometheus_client           NA
prompt_toolkit              3.0.52
psutil                      7.2.2
pure_eval                   0.2.3
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.2.3
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pygments                    2.19.2
pyparsing                   3.3.2
pythonjsonlogger            NA
referencing                 NA
requests                    2.32.5
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rfc3987_syntax              NA
rpds                        NA
scipy                       1.17.0
seaborn                     0.13.2
send2trash                  NA
six                         1.17.0
sklearn                     1.8.0
sphinxcontrib               NA
stack_data                  0.6.3
statsmodels                 0.14.6
threadpoolctl               3.6.0
tornado                     6.5.4
traitlets                   5.14.3
typing_extensions           NA
uri_template                NA
urllib3                     2.6.3
wcwidth                     0.6.0
webcolors                   NA
websocket                   1.9.0
yaml                        6.0.3
zmq                         27.1.0
zoneinfo                    NA
-----
IPython             9.10.0
jupyter_client      8.8.0
jupyter_core        5.9.1
jupyterlab          4.5.4
notebook            7.5.3
-----
Python 3.12.12 (main, Oct 10 2025, 01:01:16) [GCC 13.3.0]
Linux-6.11.0-1018-azure-x86_64-with-glibc2.39
-----
Session information updated at 2026-02-21 04:53