Comparing results of fitting CBPS between R's CBPS package and Python's balance package (using simulated data)

This notebook shows how to reproduce, almost exactly, the weights produced by R's CBPS package using the CBPS implementation in balance.

The example uses the simulated data from the help page of R's CBPS function (i.e., ?CBPS::CBPS; you can see it here).

The R code used to create the data is available here.

Loading data and fitting CBPS using balance

In [1]:
import balance
import numpy as np
import pandas as pd
import session_info

from balance import Sample
INFO (2025-07-21 09:10:04,259) [__init__/<module> (line 54)]: Using balance version 0.10.0
In [2]:
target_df, sample_df = balance.datasets.load_data("sim_data_cbps")
# print(target_df.head())
print(target_df.info())
<class 'pandas.core.frame.DataFrame'>
Index: 254 entries, 1 to 498
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   X1            254 non-null    float64
 1   X2            254 non-null    float64
 2   X3            254 non-null    float64
 3   X4            254 non-null    float64
 4   cbps_weights  254 non-null    float64
 5   y             254 non-null    float64
 6   id            254 non-null    int64  
dtypes: float64(6), int64(1)
memory usage: 15.9 KB
None
In [3]:
sample = Sample.from_frame(sample_df, outcome_columns=['y', 'cbps_weights'])
target = Sample.from_frame(target_df, outcome_columns=['y', 'cbps_weights'])
sample_target = sample.set_target(target)
WARNING (2025-07-21 09:10:04,479) [util/guess_id_column (line 113)]: Guessed id column name id for the data
WARNING (2025-07-21 09:10:04,480) [sample_class/from_frame (line 190)]: Casting id column to string
WARNING (2025-07-21 09:10:04,489) [util/_warn_of_df_dtypes_change (line 1837)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences - 
WARNING (2025-07-21 09:10:04,489) [util/_warn_of_df_dtypes_change (line 1846)]: The (old) dtypes that changed for df (before the change):
WARNING (2025-07-21 09:10:04,491) [util/_warn_of_df_dtypes_change (line 1849)]: 
id    int64
dtype: object
WARNING (2025-07-21 09:10:04,492) [util/_warn_of_df_dtypes_change (line 1850)]: The (new) dtypes saved in df (after the change):
WARNING (2025-07-21 09:10:04,493) [util/_warn_of_df_dtypes_change (line 1851)]: 
id    object
dtype: object
WARNING (2025-07-21 09:10:04,494) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2025-07-21 09:10:04,501) [util/guess_id_column (line 113)]: Guessed id column name id for the data
WARNING (2025-07-21 09:10:04,502) [sample_class/from_frame (line 190)]: Casting id column to string
WARNING (2025-07-21 09:10:04,509) [util/_warn_of_df_dtypes_change (line 1837)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences - 
WARNING (2025-07-21 09:10:04,510) [util/_warn_of_df_dtypes_change (line 1846)]: The (old) dtypes that changed for df (before the change):
WARNING (2025-07-21 09:10:04,511) [util/_warn_of_df_dtypes_change (line 1849)]: 
id    int64
dtype: object
WARNING (2025-07-21 09:10:04,511) [util/_warn_of_df_dtypes_change (line 1850)]: The (new) dtypes saved in df (after the change):
WARNING (2025-07-21 09:10:04,513) [util/_warn_of_df_dtypes_change (line 1851)]: 
id    object
dtype: object
WARNING (2025-07-21 09:10:04,513) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
In [4]:
# With its default arguments, sample_target.adjust(method="cbps") does not reproduce R's weights
# closely enough, so we disable the default covariate transformations and weight trimming:
adjust = sample_target.adjust(method="cbps", transformations=None, weight_trimming_mean_ratio=None)
INFO (2025-07-21 09:10:04,522) [cbps/cbps (line 411)]: Starting cbps function
INFO (2025-07-21 09:10:04,534) [cbps/cbps (line 461)]: The formula used to build the model matrix: ['X4 + X3 + X2 + X1']
INFO (2025-07-21 09:10:04,535) [cbps/cbps (line 473)]: The number of columns in the model matrix: 4
INFO (2025-07-21 09:10:04,536) [cbps/cbps (line 474)]: The number of rows in the model matrix: 500
INFO (2025-07-21 09:10:04,538) [cbps/cbps (line 543)]: Finding initial estimator for GMM optimization
INFO (2025-07-21 09:10:04,545) [cbps/cbps (line 570)]: Finding initial estimator for GMM optimization that minimizes the balance loss
INFO (2025-07-21 09:10:04,564) [cbps/cbps (line 605)]: Running GMM optimization
INFO (2025-07-21 09:10:04,587) [cbps/cbps (line 734)]: Done cbps function
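To build intuition for the weight_trimming_mean_ratio argument disabled above: mean-ratio trimming caps each weight at a multiple of the mean weight, so a few extreme weights cannot dominate. The sketch below is a simplified illustration under that assumption. It is not balance's actual implementation, which may also rescale the trimmed weights. trim_by_mean_ratio is a hypothetical helper, and the weights are toy values.

```python
def trim_by_mean_ratio(weights, ratio):
    """Cap every weight at `ratio` times the mean weight.

    Hypothetical helper for illustration only; balance's own trimming
    may differ (for example, by rescaling the weights after capping).
    """
    if not weights:
        raise ValueError("weights must be non-empty")
    cap = ratio * sum(weights) / len(weights)
    return [min(w, cap) for w in weights]


weights = [1.0, 1.0, 1.0, 1.0, 96.0]  # toy weights with one extreme value
# The mean weight is 20, so with ratio=2 the cap is 40.
print(trim_by_mean_ratio(weights, ratio=2))  # [1.0, 1.0, 1.0, 1.0, 40.0]
```

With trimming turned off, as in the call above, large weights are kept as estimated. This makes a like-for-like comparison with R's output possible.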

Comparing results of balance and CBPS

In [5]:
# adjust.df.plot.scatter(x="cbps_weights", y="weight", color="blue")
adjust.df[["cbps_weights", "weight"]].corr(method="pearson")
Out[5]:
              cbps_weights    weight
cbps_weights      1.000000  0.980464
weight            0.980464  1.000000
In [6]:
# adjust.df.copy().assign(log_cbps_weights=np.log(adjust.df['cbps_weights']),log_weight=np.log(adjust.df['weight'])).plot.scatter('log_cbps_weights', 'log_weight', color='blue')
np.log10(adjust.df[["cbps_weights", "weight"]]).corr(method="pearson")
Out[6]:
              cbps_weights   weight
cbps_weights       1.00000  0.99494
weight             0.99494  1.00000
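The two correlations above differ because Pearson correlation measures linear agreement. On the raw scale, a few large weights can dominate it. On the log10 scale, it instead measures agreement in relative (multiplicative) terms.

The plain-Python sketch below makes this concrete. It uses toy numbers rather than the notebook's data, and pearson and log10_pearson are illustrative helpers, not balance or pandas functions. When one set of weights is a power of the other, the raw correlation is below 1, but the log-scale correlation is 1.

```python
import math


def pearson(x, y):
    """Pearson correlation of two equal-length sequences (length >= 2)."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mx = sum(x) / len(x)
    my = sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def log10_pearson(x, y):
    """Pearson correlation after a log10 transform; all values must be positive."""
    return pearson([math.log10(a) for a in x], [math.log10(b) for b in y])


x = [1.0, 2.0, 4.0, 8.0]  # toy weights
y = [a ** 2 for a in x]   # a power function of x, so not linearly related
print(pearson(x, y))        # below 1: the relation is not linear
print(log10_pearson(x, y))  # 1 up to rounding, since log10(y) = 2 * log10(x)
```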
In [7]:
# Notice how the mean of y moves from 220.68 (unweighted) to 207.56 after weighting, similar to R's 220.68 -> 206.84
print(adjust.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.008    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             207.559  199.544     220.677  (203.353, 211.765)  (195.473, 203.616)  (216.335, 225.019)

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
             y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
        y  cbps_weights
n  254.0         254.0
%  100.0         100.0

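In the summary above, the self column is the weighted mean of each outcome in the adjusted sample, and the unadjusted column gives every respondent equal weight. The sketch below assumes the usual weighted-mean formula, sum(w * y) / sum(w). weighted_mean is a hypothetical helper and the numbers are toy values.

The formula is unchanged when all weights are multiplied by the same positive constant. So the means from R's weights and from balance's weights can be compared directly, even if the two sets of weights are normalized differently.

```python
def weighted_mean(values, weights):
    """Weighted mean sum(w * y) / sum(w); weights must have a positive sum."""
    if len(values) != len(weights):
        raise ValueError("values and weights must have the same length")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    return sum(w * y for w, y in zip(weights, values)) / total


y = [200.0, 210.0, 230.0]  # toy outcome values, not the notebook's data
w = [1.0, 1.0, 2.0]
print(weighted_mean(y, w))                      # (200 + 210 + 460) / 4 = 217.5
print(weighted_mean(y, [0.5 * v for v in w]))   # still 217.5: rescaling has no effect
```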
In [8]:
# To get a sense of what the weights did to the covariates:
adjust.covars().plot(library="seaborn", dist_type="kde")
[Figure: KDE plots of the covariate distributions after weighting with balance's CBPS weights]
In [9]:
# For comparison, swap in the original weights from R's CBPS (stored in the cbps_weights
# outcome column). deepcopy keeps the balance-adjusted object `adjust` unchanged.
from copy import deepcopy
adjust2 = deepcopy(adjust)
cbps_weights = adjust2.outcomes().df.cbps_weights
adjust2.set_weights(cbps_weights)
In [10]:
# This confirms the R weights were applied: the weighted mean of y is now 206.84, as in R
print(adjust2.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.007    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             206.844  199.544     220.677  (202.453, 211.235)  (195.473, 203.616)  (216.335, 225.019)

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
             y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
        y  cbps_weights
n  254.0         254.0
%  100.0         100.0

In [11]:
# And here is how the covariates look when weighted with the original CBPS weights from R:
# the correction is almost identical to the one balance produced
adjust2.covars().plot(library="seaborn", dist_type="kde")
[Figure: KDE plots of the covariate distributions after weighting with the original CBPS weights from R]

Session info

In [12]:
session_info.show(html=False, dependencies=True)
-----
balance             0.10.0
numpy               1.26.4
pandas              2.0.3
session_info        v1.0.1
-----
PIL                         11.3.0
anyio                       NA
arrow                       1.3.0
asttokens                   NA
attr                        25.3.0
attrs                       25.3.0
babel                       2.17.0
certifi                     2025.07.14
charset_normalizer          3.4.2
comm                        0.2.2
cycler                      0.12.1
cython_runtime              NA
dateutil                    2.9.0.post0
debugpy                     1.8.15
decorator                   5.2.1
defusedxml                  0.7.1
exceptiongroup              1.3.0
executing                   2.2.0
fastjsonschema              NA
fqdn                        NA
idna                        3.10
importlib_metadata          NA
importlib_resources         NA
ipfn                        NA
ipykernel                   6.29.5
isoduration                 NA
jedi                        0.19.2
jinja2                      3.1.6
joblib                      1.5.1
json5                       0.12.0
jsonpointer                 3.0.0
jsonschema                  4.25.0
jsonschema_specifications   NA
jupyter_events              0.12.0
jupyter_server              2.16.0
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
lark                        1.2.2
markupsafe                  3.0.2
matplotlib                  3.9.4
matplotlib_inline           0.1.7
mpl_toolkits                NA
narwhals                    1.47.1
nbformat                    5.10.4
overrides                   NA
packaging                   25.0
parso                       0.8.4
patsy                       1.0.1
pexpect                     4.9.0
platformdirs                4.3.8
plotly                      6.2.0
prometheus_client           NA
prompt_toolkit              3.0.51
psutil                      7.0.0
ptyprocess                  0.7.0
pure_eval                   0.2.3
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.2.3
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pygments                    2.19.2
pyparsing                   3.2.3
pythonjsonlogger            NA
pytz                        2025.2
referencing                 NA
requests                    2.32.4
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rfc3987_syntax              NA
rpds                        NA
scipy                       1.10.1
seaborn                     0.13.2
send2trash                  NA
six                         1.17.0
sklearn                     1.2.2
sniffio                     1.3.1
sphinxcontrib               NA
stack_data                  0.6.3
statsmodels                 0.14.5
threadpoolctl               3.6.0
tornado                     6.5.1
traitlets                   5.14.3
typing_extensions           NA
uri_template                NA
urllib3                     2.5.0
wcwidth                     0.2.13
webcolors                   NA
websocket                   1.8.0
yaml                        6.0.2
zipp                        NA
zmq                         27.0.0
zoneinfo                    NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.8.1
jupyterlab          4.4.5
notebook            7.4.4
-----
Python 3.9.23 (main, Jun  4 2025, 04:11:23) [GCC 13.3.0]
Linux-6.11.0-1018-azure-x86_64-with-glibc2.39
-----
Session information updated at 2025-07-21 09:10
/opt/hostedtoolcache/Python/3.9.23/x64/lib/python3.9/site-packages/session_info/main.py:213: UserWarning:

The '__version__' attribute is deprecated and will be removed in MarkupSafe 3.1. Use feature detection, or `importlib.metadata.version("markupsafe")`, instead.
