Comparing results of fitting CBPS between R's CBPS package and Python's balance package (using simulated data)
This notebook shows how we can reproduce (almost exactly) the weights produced by R's CBPS package, using the implementation in balance.
The example is based on simulated data that was provided in the help page of the CBPS function (i.e.: ?CBPS::CBPS, which you can see here).
The R code used to create the data is available here.
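For readers who do not want to open the R script, the general shape of that simulation is: draw four covariates, select units into the sample with a propensity that depends on them, and generate an outcome y that also depends on them. Below is a rough Python sketch of that structure only; the coefficients, seed, and selection model are illustrative placeholders, not the values used in the linked R code.

import numpy as np
import pandas as pd

# Illustrative sketch of the simulation's structure (placeholder coefficients,
# NOT the exact values from the linked R script).
rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 4))
# selection probability depends on the covariates
p = 1 / (1 + np.exp(-(X[:, 0] - 0.5 * X[:, 1] + 0.25 * X[:, 2] + 0.1 * X[:, 3])))
selected = rng.binomial(1, p).astype(bool)
# the outcome depends on the same covariates
y = 210 + 27 * X[:, 0] + 14 * (X[:, 1] + X[:, 2] + X[:, 3]) + rng.normal(size=n)

df = pd.DataFrame(X, columns=["X1", "X2", "X3", "X4"]).assign(y=y, id=np.arange(n))
sim_sample_df, sim_target_df = df[selected], df[~selected]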
Loading data and fitting CBPS using balance
In [1]:
import balance
import numpy as np
import pandas as pd
import session_info
from balance import Sample
INFO (2025-03-20 17:36:07,754) [__init__/<module> (line 54)]: Using balance version 0.10.0
In [2]:
target_df, sample_df = balance.datasets.load_data("sim_data_cbps")
# print(target_df.head())
print(target_df.info())
<class 'pandas.core.frame.DataFrame'>
Index: 254 entries, 1 to 498
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   X1            254 non-null    float64
 1   X2            254 non-null    float64
 2   X3            254 non-null    float64
 3   X4            254 non-null    float64
 4   cbps_weights  254 non-null    float64
 5   y             254 non-null    float64
 6   id            254 non-null    int64
dtypes: float64(6), int64(1)
memory usage: 15.9 KB
None
In [3]:
sample = Sample.from_frame(sample_df, outcome_columns = ['y', 'cbps_weights'])
target = Sample.from_frame(target_df, outcome_columns = ['y', 'cbps_weights'])
sample_target = sample.set_target(target)
WARNING (2025-03-20 17:36:07,979) [util/guess_id_column (line 113)]: Guessed id column name id for the data
WARNING (2025-03-20 17:36:07,980) [sample_class/from_frame (line 190)]: Casting id column to string
WARNING (2025-03-20 17:36:07,989) [util/_warn_of_df_dtypes_change (line 1837)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences -
WARNING (2025-03-20 17:36:07,990) [util/_warn_of_df_dtypes_change (line 1846)]: The (old) dtypes that changed for df (before the change):
WARNING (2025-03-20 17:36:07,991) [util/_warn_of_df_dtypes_change (line 1849)]: id int64 dtype: object
WARNING (2025-03-20 17:36:07,991) [util/_warn_of_df_dtypes_change (line 1850)]: The (new) dtypes saved in df (after the change):
WARNING (2025-03-20 17:36:07,993) [util/_warn_of_df_dtypes_change (line 1851)]: id object dtype: object
WARNING (2025-03-20 17:36:07,993) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2025-03-20 17:36:08,001) [util/guess_id_column (line 113)]: Guessed id column name id for the data
WARNING (2025-03-20 17:36:08,002) [sample_class/from_frame (line 190)]: Casting id column to string
WARNING (2025-03-20 17:36:08,010) [util/_warn_of_df_dtypes_change (line 1837)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences -
WARNING (2025-03-20 17:36:08,010) [util/_warn_of_df_dtypes_change (line 1846)]: The (old) dtypes that changed for df (before the change):
WARNING (2025-03-20 17:36:08,012) [util/_warn_of_df_dtypes_change (line 1849)]: id int64 dtype: object
WARNING (2025-03-20 17:36:08,013) [util/_warn_of_df_dtypes_change (line 1850)]: The (new) dtypes saved in df (after the change):
WARNING (2025-03-20 17:36:08,014) [util/_warn_of_df_dtypes_change (line 1851)]: id object dtype: object
WARNING (2025-03-20 17:36:08,014) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
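The warnings above are expected: balance guesses which column is the id, casts it to string, and adds a unit weight column since no weights were passed. If the guessing is undesirable, the id column can be named explicitly; a small variation on the cell above, assuming Sample.from_frame accepts an id_column argument (as in recent balance versions):

sample = Sample.from_frame(sample_df, id_column = "id", outcome_columns = ['y', 'cbps_weights'])
target = Sample.from_frame(target_df, id_column = "id", outcome_columns = ['y', 'cbps_weights'])
sample_target = sample.set_target(target)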
In [4]:
# adjust = sample_target.adjust(method = "cbps") # the defaults of the function would not yield similar-enough results, so we need to adjust some parameters:
adjust = sample_target.adjust(method = "cbps", transformations = None, weight_trimming_mean_ratio = None)
INFO (2025-03-20 17:36:08,023) [cbps/cbps (line 411)]: Starting cbps function
INFO (2025-03-20 17:36:08,035) [cbps/cbps (line 461)]: The formula used to build the model matrix: ['X4 + X3 + X2 + X1']
INFO (2025-03-20 17:36:08,036) [cbps/cbps (line 473)]: The number of columns in the model matrix: 4
INFO (2025-03-20 17:36:08,036) [cbps/cbps (line 474)]: The number of rows in the model matrix: 500
INFO (2025-03-20 17:36:08,039) [cbps/cbps (line 543)]: Finding initial estimator for GMM optimization
INFO (2025-03-20 17:36:08,046) [cbps/cbps (line 570)]: Finding initial estimator for GMM optimization that minimizes the balance loss
INFO (2025-03-20 17:36:08,068) [cbps/cbps (line 605)]: Running GMM optimization
INFO (2025-03-20 17:36:08,093) [cbps/cbps (line 734)]: Done cbps function
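Before comparing against R, it is worth glancing at the distribution of the weights balance just fitted (stored in the weight column of the adjusted sample); a short sketch, also computing Kish's design effect as a rough effective-sample-size check:

# distribution of the weights fitted by balance
w = adjust.df["weight"]
print(w.describe())
# Kish's design effect: n * sum(w^2) / (sum(w))^2
print("design effect:", len(w) * (w ** 2).sum() / (w.sum() ** 2))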
Comparing results of balance and CBPS
In [5]:
# adjust.df.plot.scatter(x="cbps_weights", y="weight", color="blue")
adjust.df[["cbps_weights", "weight"]].corr(method = "pearson")
Out[5]:
              cbps_weights    weight
cbps_weights      1.000000  0.980464
weight            0.980464  1.000000
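Correlation alone can hide scale differences (the two sets of weights need not sum to the same total), so another quick check is to normalize both to sum to 1 and compare them unit by unit; a minimal sketch against the same adjust.df:

w_balance = adjust.df["weight"] / adjust.df["weight"].sum()
w_r = adjust.df["cbps_weights"] / adjust.df["cbps_weights"].sum()
abs_diff = (w_balance - w_r).abs()
print(abs_diff.describe())                           # unit-level absolute differences
print("max relative diff:", (abs_diff / w_r).max())  # worst-case relative disagreement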
In [6]:
# adjust.df.copy().assign(log_cbps_weights=np.log(adjust.df['cbps_weights']),log_weight=np.log(adjust.df['weight'])).plot.scatter('log_cbps_weights', 'log_weight', color='blue')
adjust.df[["cbps_weights", "weight"]].apply(lambda x: np.log10(x)).corr(method = "pearson")
Out[6]:
              cbps_weights   weight
cbps_weights       1.00000  0.99494
weight             0.99494  1.00000
In [7]:
# Notice how the mean of the outcome y goes from 220.68 before weighting to 207.56 after, similar to R's 220.68 -> 206.84
print(adjust.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.008    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             207.559  199.544     220.677  (203.353, 211.765)  (195.473, 203.616)  (216.335, 225.019)

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
            y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
       y  cbps_weights
n  254.0         254.0
%  100.0         100.0
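The adjusted mean of y reported above (~207.56) is simply the weighted mean of y under the new weights, which can be reproduced directly from the adjusted DataFrame:

# Reproduce the adjusted and unadjusted means of y by hand
print(np.average(adjust.df["y"], weights=adjust.df["weight"]))  # ~207.56
print(adjust.df["y"].mean())                                    # ~220.68 (unweighted)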
In [8]:
# Just to get some sense of what the weights did to the covars:
adjust.covars().plot(library = "seaborn", dist_type = "kde")
In [9]:
# In contrast, if we were to use the original CBPS weights, we'd get this:
from copy import deepcopy
adjust2 = deepcopy(adjust)
cbps_weights = adjust2.outcomes().df.cbps_weights
adjust2.set_weights(cbps_weights)
# .covars().plot(library = "seaborn", dist_type = "kde")
In [10]:
# We can see that this worked, since the weighted average of y is now 206.84
print(adjust2.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.007    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             206.844  199.544     220.677  (202.453, 211.235)  (195.473, 203.616)  (216.335, 225.019)

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
            y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
       y  cbps_weights
n  254.0         254.0
%  100.0         100.0
In [11]:
# And here is how the covariates look under the original weights from R's CBPS implementation:
# the correction is almost identical to the one balance produced
adjust2.covars().plot(library = "seaborn", dist_type = "kde")
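To put a number on how similar the two corrections are beyond eyeballing the density plots, one could compare the absolute standardized mean differences (ASMD) of the covariates under each set of weights; a sketch, assuming the covars accessor exposes an asmd() method as shown in the balance tutorials:

print(adjust.covars().asmd())   # covariate balance under the weights fitted by balance
print(adjust2.covars().asmd())  # covariate balance under R's original CBPS weights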
Session info
In [12]:
session_info.show(html=False, dependencies=True)
-----
balance             0.10.0
numpy               1.26.4
pandas              2.0.3
session_info        1.0.0
-----
PIL 11.1.0 anyio NA arrow 1.3.0 asttokens NA attr 25.3.0 attrs 25.3.0 babel 2.17.0 certifi 2025.01.31 charset_normalizer 3.4.1 comm 0.2.2 cycler 0.12.1 cython_runtime NA dateutil 2.9.0.post0 debugpy 1.8.13 decorator 5.2.1 defusedxml 0.7.1 exceptiongroup 1.2.2 executing 2.2.0 fastjsonschema NA fqdn NA idna 3.10 importlib_metadata NA importlib_resources NA ipfn NA ipykernel 6.29.5 isoduration NA jedi 0.19.2 jinja2 3.1.6 joblib 1.4.2 json5 0.10.0 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema_specifications NA jupyter_events 0.12.0 jupyter_server 2.15.0 jupyterlab_server 2.27.3 kiwisolver 1.4.7 markupsafe 3.0.2 matplotlib 3.9.4 matplotlib_inline 0.1.7 mpl_toolkits NA narwhals 1.31.0 nbformat 5.10.4 overrides NA packaging 24.2 parso 0.8.4 patsy 1.0.1 pexpect 4.9.0 platformdirs 4.3.7 plotly 6.0.1 prometheus_client NA prompt_toolkit 3.0.50 psutil 7.0.0 ptyprocess 0.7.0 pure_eval 0.2.3 pydev_ipython NA pydevconsole NA pydevd 3.2.3 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.19.1 pyparsing 3.2.1 pythonjsonlogger NA pytz 2025.1 referencing NA requests 2.32.3 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rpds NA scipy 1.10.1 seaborn 0.13.2 send2trash NA six 1.17.0 sklearn 1.2.2 sniffio 1.3.1 sphinxcontrib NA stack_data 0.6.3 statsmodels 0.14.4 threadpoolctl 3.6.0 tornado 6.4.2 traitlets 5.14.3 typing_extensions NA uri_template NA urllib3 2.3.0 wcwidth 0.2.13 webcolors NA websocket 1.8.0 yaml 6.0.2 zipp NA zmq 26.3.0 zoneinfo NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.3.6
notebook            7.3.3
-----
Python 3.9.21 (main, Dec 12 2024, 19:08:08) [GCC 13.2.0]
Linux-6.8.0-1021-azure-x86_64-with-glibc2.39
-----
Session information updated at 2025-03-20 17:36
/opt/hostedtoolcache/Python/3.9.21/x64/lib/python3.9/site-packages/session_info/main.py:213: UserWarning: The '__version__' attribute is deprecated and will be removed in MarkupSafe 3.1. Use feature detection, or `importlib.metadata.version("markupsafe")`, instead.