Comparing results of fitting CBPS between R's CBPS package and Python's balance package (using simulated data)¶
This notebook shows how we can reproduce (almost exactly) the weights produced by R's CBPS package, using the implementation in balance.
The example is based on the simulated data provided in the help page of the CBPS function (i.e.: ?CBPS::CBPS, you can see it here).
The R code used to create the data is available here.
Loading data and fitting CBPS using balance¶
In [1]:
import balance
import numpy as np
import pandas as pd
import session_info
from balance import Sample
INFO (2024-11-21 05:27:22,145) [__init__/<module> (line 54)]: Using balance version 0.9.1
In [2]:
target_df, sample_df = balance.datasets.load_data("sim_data_cbps")
# print(target_df.head())
print(target_df.info())
<class 'pandas.core.frame.DataFrame'>
Int64Index: 254 entries, 1 to 498
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   X1            254 non-null    float64
 1   X2            254 non-null    float64
 2   X3            254 non-null    float64
 3   X4            254 non-null    float64
 4   cbps_weights  254 non-null    float64
 5   y             254 non-null    float64
 6   id            254 non-null    int64
dtypes: float64(6), int64(1)
memory usage: 15.9 KB
None
/home/runner/.local/lib/python3.10/site-packages/balance/datasets/__init__.py:160: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
/home/runner/.local/lib/python3.10/site-packages/balance/datasets/__init__.py:161: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
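As a quick sanity check (not part of the original flow), one could also peek at the shapes of the two simulated frames; the column names and row counts referenced in the comments come from the outputs shown in this notebook:

# target_df has 254 rows; sample_df holds the remaining simulated units
# (246 rows, as seen in the response-rate summaries later in this notebook)
print(target_df.shape)
print(sample_df.shape)
print(sample_df.columns.tolist())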
In [3]:
sample = Sample.from_frame(sample_df, outcome_columns = ['y', 'cbps_weights'])
target = Sample.from_frame(target_df, outcome_columns = ['y', 'cbps_weights'])
sample_target = sample.set_target(target)
WARNING (2024-11-21 05:27:22,303) [util/guess_id_column (line 114)]: Guessed id column name id for the data
WARNING (2024-11-21 05:27:22,304) [sample_class/from_frame (line 190)]: Casting id column to string
WARNING (2024-11-21 05:27:22,311) [util/_warn_of_df_dtypes_change (line 1839)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences -
WARNING (2024-11-21 05:27:22,311) [util/_warn_of_df_dtypes_change (line 1850)]: The (old) dtypes that changed for df (before the change):
WARNING (2024-11-21 05:27:22,312) [util/_warn_of_df_dtypes_change (line 1853)]: id int64 dtype: object
WARNING (2024-11-21 05:27:22,313) [util/_warn_of_df_dtypes_change (line 1854)]: The (new) dtypes saved in df (after the change):
WARNING (2024-11-21 05:27:22,314) [util/_warn_of_df_dtypes_change (line 1855)]: id object dtype: object
WARNING (2024-11-21 05:27:22,314) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2024-11-21 05:27:22,322) [util/guess_id_column (line 114)]: Guessed id column name id for the data
WARNING (2024-11-21 05:27:22,322) [sample_class/from_frame (line 190)]: Casting id column to string
WARNING (2024-11-21 05:27:22,328) [util/_warn_of_df_dtypes_change (line 1839)]: The dtypes of sample._df were changed from the original dtypes of the input df, here are the differences -
WARNING (2024-11-21 05:27:22,329) [util/_warn_of_df_dtypes_change (line 1850)]: The (old) dtypes that changed for df (before the change):
WARNING (2024-11-21 05:27:22,330) [util/_warn_of_df_dtypes_change (line 1853)]: id int64 dtype: object
WARNING (2024-11-21 05:27:22,330) [util/_warn_of_df_dtypes_change (line 1854)]: The (new) dtypes saved in df (after the change):
WARNING (2024-11-21 05:27:22,331) [util/_warn_of_df_dtypes_change (line 1855)]: id object dtype: object
WARNING (2024-11-21 05:27:22,331) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
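The warnings above are expected: balance casts the id column to string and, since no weights were passed, adds a weight column of 1s. A minimal (hypothetical) sanity check of the constructed objects could look like this, using the same .df accessor that is used later in this notebook:

# Inspect the stored DataFrame of the sample object; a 'weight' column of 1s
# should have been added alongside the original covariates and outcomes
print(sample.df.head())
print(sample.df["weight"].describe())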
In [4]:
# adjust = sample_target.adjust(method = "cbps") # the defaults of the function would not yield similar-enough results, so we need to adjust some parameters:
adjust = sample_target.adjust(method = "cbps", transformations = None, weight_trimming_mean_ratio = None)
INFO (2024-11-21 05:27:22,339) [cbps/cbps (line 411)]: Starting cbps function
INFO (2024-11-21 05:27:22,350) [cbps/cbps (line 462)]: The formula used to build the model matrix: ['X4 + X3 + X2 + X1']
INFO (2024-11-21 05:27:22,350) [cbps/cbps (line 474)]: The number of columns in the model matrix: 4
INFO (2024-11-21 05:27:22,351) [cbps/cbps (line 475)]: The number of rows in the model matrix: 500
INFO (2024-11-21 05:27:22,353) [cbps/cbps (line 537)]: Finding initial estimator for GMM optimization
INFO (2024-11-21 05:27:22,359) [cbps/cbps (line 564)]: Finding initial estimator for GMM optimization that minimizes the balance loss
INFO (2024-11-21 05:27:22,380) [cbps/cbps (line 599)]: Running GMM optimization
INFO (2024-11-21 05:27:22,405) [cbps/cbps (line 730)]: Done cbps function
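Note the two non-default arguments: transformations=None turns off balance's automatic covariate transformations, and weight_trimming_mean_ratio=None disables weight trimming, so the raw CBPS weights are kept and remain comparable to R's. A small sketch (assuming, as used below, that the fitted Sample exposes its weights through adjust.df) of how one might inspect them:

# Summary of the fitted weights; with trimming disabled these are the raw
# CBPS weights (up to normalization), which is what we compare against R
print(adjust.df["weight"].describe())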
Comparing results of balance and CBPS¶
In [5]:
# adjust.df.plot.scatter(x="cbps_weights", y="weight", color="blue")
adjust.df[["cbps_weights", "weight"]].corr(method = "pearson")
Out[5]:
|              | cbps_weights | weight   |
|--------------|--------------|----------|
| cbps_weights | 1.000000     | 0.980464 |
| weight       | 0.980464     | 1.000000 |
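Since propensity-score weights are only defined up to a multiplicative constant, another (hypothetical) check is to normalize both weight columns to sum to 1 and look at their element-wise ratio:

# Normalize each weight column to sum to 1, then compare them row by row;
# ratios close to 1 indicate the two sets of weights are nearly identical
w = adjust.df[["cbps_weights", "weight"]].dropna()
w_norm = w / w.sum()
print((w_norm["weight"] / w_norm["cbps_weights"]).describe())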
In [6]:
# adjust.df.copy().assign(log_cbps_weights=np.log(adjust.df['cbps_weights']),log_weight=np.log(adjust.df['weight'])).plot.scatter('log_cbps_weights', 'log_weight', color='blue')
adjust.df[["cbps_weights", "weight"]].apply(lambda x: np.log10(x)).corr(method = "pearson")
Out[6]:
|              | cbps_weights | weight  |
|--------------|--------------|---------|
| cbps_weights | 1.00000      | 0.99494 |
| weight       | 0.99494      | 1.00000 |
In [7]:
# Notice how the mean of the y outcome before and after weighting goes from 220.67 to 207.55, similar to R's 220.67 -> 206.8
print(adjust.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.008    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             207.559  199.544     220.677  (203.353, 211.765)  (195.473, 203.616)  (216.335, 225.019)

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
            y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
       y  cbps_weights
n  254.0         254.0
%  100.0         100.0
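The same before/after means of y can be reproduced directly from the weights with numpy (a minimal sketch, assuming adjust.df carries the y outcome and the fitted weight column as above):

# Unweighted vs. CBPS-weighted mean of y among the sample respondents;
# these should match the 'unadjusted' (~220.7) and 'self' (~207.6) rows above
resp = adjust.df.dropna(subset=["y"])
print(np.average(resp["y"]))
print(np.average(resp["y"], weights=resp["weight"]))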
In [8]:
# Just to get some sense of what the weights did to the covars:
adjust.covars().plot(library = "seaborn", dist_type = "kde")
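For a numeric (rather than visual) view of covariate balance, a hedged sketch using summaries that recent versions of balance expose (availability may depend on your installed version):

# Mean of each covariate under the fitted weights, compared to the target;
# asmd() (if available in your balance version) summarizes the remaining imbalance
print(adjust.covars().mean().T)
print(adjust.covars().asmd().T)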
In [9]:
# In contrast, if we were to use the original CBPS weights, we'd get this:
from copy import deepcopy
adjust2 = deepcopy(adjust)
cbps_weights = adjust2.outcomes().df.cbps_weights
adjust2.set_weights(cbps_weights)
# .covars().plot(library = "seaborn", dist_type = "kde")
In [10]:
# We can see that this worked, since the weighted average of y is now 206.8
print(adjust2.outcomes().summary())
2 outcomes: ['y' 'cbps_weights']
Mean outcomes (with 95% confidence intervals):
source           self   target  unadjusted             self_ci           target_ci       unadjusted_ci
cbps_weights    0.007    0.004       0.004      (0.006, 0.009)      (0.004, 0.004)      (0.004, 0.005)
y             206.844  199.544     220.677  (202.453, 211.235)  (195.473, 203.616)  (216.335, 225.019)

Response rates (relative to number of respondents in sample):
       y  cbps_weights
n  246.0         246.0
%  100.0         100.0
Response rates (relative to notnull rows in the target):
            y  cbps_weights
n  246.000000    246.000000
%   96.850394     96.850394
Response rates (in the target):
       y  cbps_weights
n  254.0         254.0
%  100.0         100.0
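As a direct check (again a sketch outside the original flow), the 206.8 figure can be recomputed with numpy using the original R weights:

# Weighted mean of y under the original CBPS weights from R; should be ~206.8
resp2 = adjust2.df.dropna(subset=["y"])
print(np.average(resp2["y"], weights=resp2["cbps_weights"]))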
In [11]:
# And here is how the covariates look under the original CBPS weights from R:
# The correction is almost identical to the one balance produced
adjust2.covars().plot(library = "seaborn", dist_type = "kde")
Session info¶
In [12]:
session_info.show(html=False, dependencies=True)
----- balance 0.9.1 numpy 1.24.4 pandas 1.4.3 session_info 1.0.0 ----- PIL 11.0.0 anyio NA apport_python_hook NA argcomplete NA arrow 1.3.0 asttokens NA attr 24.2.0 attrs 24.2.0 babel 2.16.0 beta_ufunc NA binom_ufunc NA certifi 2020.06.20 chardet 4.0.0 charset_normalizer 3.4.0 colorama 0.4.4 comm 0.2.2 coxnet NA cvcompute NA cvelnet NA cvfishnet NA cvglmnet NA cvglmnetCoef NA cvglmnetPredict NA cvlognet NA cvmrelnet NA cvmultnet NA cycler 0.12.1 cython_runtime NA dateutil 2.9.0.post0 debugpy 1.8.8 decorator 5.1.1 defusedxml 0.7.1 elnet NA exceptiongroup 1.2.2 executing 2.1.0 fastjsonschema NA fishnet NA fqdn NA gi 3.42.1 gio NA glib NA glmnet NA glmnetCoef NA glmnetControl NA glmnetPredict NA glmnetSet NA glmnet_python NA gobject NA gtk NA hypergeom_ufunc NA idna 3.3 ipfn NA ipykernel 6.29.5 isoduration NA jedi 0.19.2 jinja2 3.1.4 joblib 1.4.2 json5 0.9.28 jsonpointer 2.0 jsonschema 4.23.0 jsonschema_specifications NA jupyter_events 0.10.0 jupyter_server 2.14.2 jupyterlab_server 2.27.3 kiwisolver 1.4.7 loadGlmLib NA lognet NA markupsafe 2.0.1 matplotlib 3.9.2 matplotlib_inline 0.1.7 mpl_toolkits NA mrelnet NA nbformat 5.10.4 nbinom_ufunc NA ncf_ufunc NA overrides NA packaging 24.2 parso 0.8.4 patsy 1.0.1 platformdirs 4.3.6 plotly 5.24.1 prometheus_client NA prompt_toolkit 3.0.48 psutil 6.1.0 pure_eval 0.2.3 pydev_ipython NA pydevconsole NA pydevd 3.2.2 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.18.0 pyparsing 2.4.7 pythonjsonlogger NA pytz 2022.1 referencing NA requests 2.32.3 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rpds NA scipy 1.9.1 seaborn 0.13.0 send2trash NA sitecustomize NA six 1.16.0 sklearn 1.5.2 sniffio 1.3.1 sphinxcontrib NA stack_data 0.6.3 statsmodels 0.14.4 threadpoolctl 3.5.0 tornado 6.4.1 traitlets 5.14.3 typing_extensions NA uri_template NA urllib3 1.26.5 wcwidth 0.2.13 webcolors NA websocket 1.8.0 wtmean NA yaml 5.4.1 zmq 26.2.0 zoneinfo NA zope NA ----- IPython 8.29.0 jupyter_client 8.6.3 jupyter_core 5.7.2 jupyterlab 4.2.6 notebook 7.2.2 ----- Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] Linux-6.5.0-1025-azure-x86_64-with-glibc2.35 ----- Session information updated at 2024-11-21 05:27