balance Quickstart (CBPS): Analyzing and adjusting the bias on a simulated toy dataset¶
'balance' is a Python package maintained and released by the Core Data Science Tel Aviv team at Meta. 'balance' performs and evaluates bias reduction by weighting, for a broad set of experimental and observational use cases.
Although balance is written in Python, you don't need a deep understanding of Python to use it. In fact, you can use this notebook as-is: load your data, change a few variables, re-run the notebook, and produce your own weights!
This quickstart demonstrates re-weighting a specific simulated dataset. If you have a different use case or want more comprehensive documentation, check out the comprehensive balance tutorial.
Analysis¶
There are four main steps to analysis with balance:
- load data
- check diagnostics before adjustment
- perform adjustment + check diagnostics
- output results
Let's dive right in!
Example dataset¶
The following is a toy simulated dataset.
from balance import load_data
INFO (2024-11-21 05:26:33,149) [__init__/<module> (line 54)]: Using balance version 0.9.1
target_df, sample_df = load_data()
print("target_df: \n", target_df.head())
print("sample_df: \n", sample_df.head())
target_df:
        id gender age_group     income  happiness
0  100000   Male       45+  10.183951  61.706333
1  100001   Male       45+   6.036858  79.123670
2  100002   Male     35-44   5.226629  44.206949
3  100003    NaN       45+   5.752147  83.985716
4  100004    NaN     25-34   4.837484  49.339713
sample_df:
   id  gender age_group     income  happiness
0   0    Male     25-34   6.428659  26.043029
1   1  Female     18-24   9.940280  66.885485
2   2    Male     18-24   2.673623  37.091922
3   3     NaN     18-24  10.550308  49.394050
4   4     NaN     18-24   2.689994  72.304208
target_df.head().round(2).to_dict()
# sample_df.shape
{'id': {0: '100000', 1: '100001', 2: '100002', 3: '100003', 4: '100004'}, 'gender': {0: 'Male', 1: 'Male', 2: 'Male', 3: nan, 4: nan}, 'age_group': {0: '45+', 1: '45+', 2: '35-44', 3: '45+', 4: '25-34'}, 'income': {0: 10.18, 1: 6.04, 2: 5.23, 3: 5.75, 4: 4.84}, 'happiness': {0: 61.71, 1: 79.12, 2: 44.21, 3: 83.99, 4: 49.34}}
In practice, one can use a pandas loading function (such as read_csv()) to import data into the DataFrame objects sample_df and target_df.
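For example, a minimal sketch of such loading (the file names below are hypothetical placeholders for your own data files):
import pandas as pd

# Hypothetical file paths - replace with your own data files.
sample_df = pd.read_csv("sample.csv")
target_df = pd.read_csv("target.csv")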
Load data into a Sample object¶
The first thing to do is to import the Sample class from balance. All of the data we're going to be working with, sample or population, will be stored in objects of the Sample class.
from balance import Sample
Using the Sample class, we can store the "sample" we want to adjust, as well as the "target" we want to adjust towards. We turn the two pandas DataFrame objects we created (or loaded) into balance.Sample objects by using the .from_frame() method:
sample = Sample.from_frame(sample_df, outcome_columns=["happiness"])
target = Sample.from_frame(target_df, outcome_columns=["happiness"])
WARNING (2024-11-21 05:26:33,342) [util/guess_id_column (line 114)]: Guessed id column name id for the data
WARNING (2024-11-21 05:26:33,349) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2024-11-21 05:26:33,358) [util/guess_id_column (line 114)]: Guessed id column name id for the data
WARNING (2024-11-21 05:26:33,370) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
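The warnings above show that balance guessed the id column and added a weight column of all 1s. To be explicit about these choices, the column name can presumably be passed directly; a sketch, assuming from_frame accepts an id_column argument (as the guessing warning suggests):
# A sketch: passing the id column explicitly instead of letting balance guess it.
sample = Sample.from_frame(sample_df, id_column="id", outcome_columns=["happiness"])
target = Sample.from_frame(target_df, id_column="id", outcome_columns=["happiness"])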
Using the .df property, we can see the DataFrame stored in sample. Note the new weight column (all 1s) that was added when importing the DataFrame into a balance.Sample object:
sample.df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 6 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   id         1000 non-null   object
 1   gender     912 non-null    object
 2   age_group  1000 non-null   object
 3   income     1000 non-null   float64
 4   happiness  1000 non-null   float64
 5   weight     1000 non-null   int64
dtypes: float64(2), int64(1), object(3)
memory usage: 47.0+ KB
We can get a quick text overview of each Sample object by simply calling it.
Let's take a look at what this produces:
sample
(balance.sample_class.Sample)
balance Sample object
1000 observations x 3 variables: gender,age_group,income
id_column: id, weight_column: weight,
outcome_columns: happiness
target
(balance.sample_class.Sample)
balance Sample object
10000 observations x 3 variables: gender,age_group,income
id_column: id, weight_column: weight,
outcome_columns: happiness
Next, we combine the sample object with the target object. This is what will allow us to adjust the sample to the target.
sample_with_target = sample.set_target(target)
Looking at sample_with_target now, it has the target attached:
sample_with_target
(balance.sample_class.Sample)
balance Sample object with target set
1000 observations x 3 variables: gender,age_group,income
id_column: id, weight_column: weight,
outcome_columns: happiness
target:
    balance Sample object
    10000 observations x 3 variables: gender,age_group,income
    id_column: id, weight_column: weight,
    outcome_columns: happiness
3 common variables: gender,age_group,income
Pre-Adjustment Diagnostics¶
We can use .covars() and then follow up with .mean() and .plot() (barplots and qqplots) to get some basic diagnostics on what we have.
We can see that:
- The proportion of missing values in gender is similar in the sample and the target.
- We have younger people in the sample as compared to the target.
- We have more males than females in the sample, as compared to a roughly 50-50 split in the (non-NA) target.
- Income is more right-skewed in the target as compared to the sample.
print(sample_with_target.covars().mean().T)
source                      self     target
_is_na_gender[T.True]   0.088000   0.089800
age_group[T.25-34]      0.300000   0.297400
age_group[T.35-44]      0.156000   0.299200
age_group[T.45+]        0.053000   0.206300
gender[Female]          0.268000   0.455100
gender[Male]            0.644000   0.455100
gender[_NA]             0.088000   0.089800
income                  6.297302  12.737608
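The next diagnostic is the ASMD (absolute standardized mean difference) of each covariate: roughly, the absolute difference between the sample mean and the target mean, divided by the target's standard deviation, with categorical variables compared per one-hot-encoded level. As an illustrative sanity check, here is a minimal sketch of that simplified definition for income (an assumption: balance's exact denominator convention may differ):
import numpy as np

# Simplified manual ASMD for income: |mean(sample) - mean(target)| / std(target).
# Illustrative only - balance's exact definition may use a different denominator.
abs_diff = np.abs(sample_df["income"].mean() - target_df["income"].mean())
print(abs_diff / target_df["income"].std())  # expected close to the income row below (~0.49)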
print(sample_with_target.covars().asmd().T)
source                  self
age_group[T.25-34]  0.005688
age_group[T.35-44]  0.312711
age_group[T.45+]    0.378828
gender[Female]      0.375699
gender[Male]        0.379314
gender[_NA]         0.006296
income              0.494217
mean(asmd)          0.326799
print(sample_with_target.covars().asmd(aggregate_by_main_covar = True).T)
source          self
age_group   0.232409
gender      0.253769
income      0.494217
mean(asmd)  0.326799
sample_with_target.covars().plot()
Adjusting Sample to Population (ipw and cbps)¶
Next, we adjust the sample to the target. The default method is 'ipw', which computes inverse probability/propensity weights by fitting a logistic regression with LASSO regularization.
# Using ipw to fit survey weights
adjusted_ipw = sample_with_target.adjust()
INFO (2024-11-21 05:26:34,143) [ipw/ipw (line 424)]: Starting ipw function
INFO (2024-11-21 05:26:34,146) [adjustment/apply_transformations (line 306)]: Adding the variables: []
INFO (2024-11-21 05:26:34,147) [adjustment/apply_transformations (line 307)]: Transforming the variables: ['gender', 'age_group', 'income']
INFO (2024-11-21 05:26:34,159) [adjustment/apply_transformations (line 347)]: Final variables in output: ['gender', 'age_group', 'income']
INFO (2024-11-21 05:26:34,169) [ipw/ipw (line 458)]: Building model matrix
INFO (2024-11-21 05:26:34,259) [ipw/ipw (line 482)]: The formula used to build the model matrix: ['income + gender + age_group + _is_na_gender']
INFO (2024-11-21 05:26:34,260) [ipw/ipw (line 485)]: The number of columns in the model matrix: 16
INFO (2024-11-21 05:26:34,260) [ipw/ipw (line 486)]: The number of rows in the model matrix: 11000
INFO (2024-11-21 05:26:34,267) [ipw/ipw (line 517)]: Fitting logistic model
INFO (2024-11-21 05:26:35,569) [ipw/ipw (line 558)]: max_de: None
INFO (2024-11-21 05:26:35,573) [ipw/ipw (line 588)]: Chosen lambda for cv: [0.0131066]
INFO (2024-11-21 05:26:35,575) [ipw/ipw (line 596)]: Proportion null deviance explained [0.17168419]
adjusted_cbps = sample_with_target.adjust(method = "cbps")
INFO (2024-11-21 05:26:35,587) [cbps/cbps (line 411)]: Starting cbps function
INFO (2024-11-21 05:26:35,589) [adjustment/apply_transformations (line 306)]: Adding the variables: []
INFO (2024-11-21 05:26:35,590) [adjustment/apply_transformations (line 307)]: Transforming the variables: ['gender', 'age_group', 'income']
INFO (2024-11-21 05:26:35,602) [adjustment/apply_transformations (line 347)]: Final variables in output: ['gender', 'age_group', 'income']
INFO (2024-11-21 05:26:35,716) [cbps/cbps (line 462)]: The formula used to build the model matrix: ['income + gender + age_group + _is_na_gender']
INFO (2024-11-21 05:26:35,718) [cbps/cbps (line 474)]: The number of columns in the model matrix: 16
INFO (2024-11-21 05:26:35,718) [cbps/cbps (line 475)]: The number of rows in the model matrix: 11000
INFO (2024-11-21 05:26:35,729) [cbps/cbps (line 537)]: Finding initial estimator for GMM optimization
INFO (2024-11-21 05:26:35,811) [cbps/cbps (line 564)]: Finding initial estimator for GMM optimization that minimizes the balance loss
WARNING (2024-11-21 05:26:36,195) [cbps/cbps (line 581)]: Convergence of bal_loss function has failed due to 'Maximum number of function evaluations has been exceeded.'
INFO (2024-11-21 05:26:36,196) [cbps/cbps (line 599)]: Running GMM optimization
WARNING (2024-11-21 05:26:36,730) [cbps/cbps (line 614)]: Convergence of gmm_loss function with gmm_init start point has failed due to 'Maximum number of function evaluations has been exceeded.'
WARNING (2024-11-21 05:26:37,268) [cbps/cbps (line 632)]: Convergence of gmm_loss function with beta_balance start point has failed due to 'Maximum number of function evaluations has been exceeded.'
INFO (2024-11-21 05:26:37,275) [cbps/cbps (line 730)]: Done cbps function
print(adjusted_ipw)
Adjusted balance Sample object with target set using ipw
1000 observations x 3 variables: gender,age_group,income
id_column: id, weight_column: weight,
outcome_columns: happiness
target:
    balance Sample object
    10000 observations x 3 variables: gender,age_group,income
    id_column: id, weight_column: weight,
    outcome_columns: happiness
3 common variables: gender,age_group,income
# the cbps-adjusted object prints the same way as the ipw one, except for the method name
print(adjusted_cbps)
Adjusted balance Sample object with target set using cbps
1000 observations x 3 variables: gender,age_group,income
id_column: id, weight_column: weight,
outcome_columns: happiness
target:
    balance Sample object
    10000 observations x 3 variables: gender,age_group,income
    id_column: id, weight_column: weight,
    outcome_columns: happiness
3 common variables: gender,age_group,income
Evaluation of the Results (CBPS vs IPW)¶
We can get a basic summary of the results:
print(adjusted_ipw.summary())
Covar ASMD reduction: 59.7%, design effect: 1.897
Covar ASMD (7 variables): 0.327 -> 0.132
Model performance: Model proportion deviance explained: 0.172
print(adjusted_cbps.summary())
Covar ASMD reduction: 77.6%, design effect: 2.782
Covar ASMD (7 variables): 0.327 -> 0.073
We can see that CBPS did a better job in terms of ASMD reduction, improving the average ASMD more than ipw did. We can look at the detailed list of ASMD values per variable using the following calls:
print("ipw:")
print(adjusted_ipw.covars().asmd().T)
print("\ncbps:")
print(adjusted_cbps.covars().asmd().T)
ipw:
source                  self  unadjusted  unadjusted - self
age_group[T.25-34]  0.001085    0.005688           0.004602
age_group[T.35-44]  0.037455    0.312711           0.275256
age_group[T.45+]    0.129304    0.378828           0.249525
gender[Female]      0.133970    0.375699           0.241730
gender[Male]        0.109697    0.379314           0.269617
gender[_NA]         0.042278    0.006296          -0.035983
income              0.243762    0.494217           0.250455
mean(asmd)          0.131675    0.326799           0.195124

cbps:
source                  self  unadjusted  unadjusted - self
age_group[T.25-34]  0.051879    0.005688          -0.046192
age_group[T.35-44]  0.031114    0.312711           0.281597
age_group[T.45+]    0.105655    0.378828           0.273173
gender[Female]      0.034514    0.375699           0.341185
gender[Male]        0.058580    0.379314           0.320733
gender[_NA]         0.041919    0.006296          -0.035624
income              0.111468    0.494217           0.382749
mean(asmd)          0.073118    0.326799           0.253680
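To compare the two methods side by side, one could join the per-variable ASMD columns; a small sketch, assuming .covars().asmd() returns a pandas DataFrame (as the printed output suggests):
import pandas as pd

# Side-by-side adjusted ("self") ASMD per variable for the two methods.
comparison = pd.concat(
    {
        "ipw": adjusted_ipw.covars().asmd().T["self"],
        "cbps": adjusted_cbps.covars().asmd().T["self"],
    },
    axis=1,
)
print(comparison)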
It's easier to learn about the biases by just running .covars().plot() on our adjusted objects. We can also use different plots, for example with the seaborn library and the "kde" dist_type:
adjusted_ipw.covars().plot(library = "seaborn", dist_type = "kde")
adjusted_cbps.covars().plot(library = "seaborn", dist_type = "kde")
Understanding the weights¶
We can get the design effect of the resulting weights using:
print("ipw:")
print(adjusted_ipw.weights().design_effect())
print("\ncbps:")
print(adjusted_cbps.weights().design_effect())
ipw:
1.8973847221820574

cbps:
2.7816765614638572
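The reported design effect appears to be Kish's: deff = n * sum(w_i^2) / (sum(w_i))^2. As a sanity check, a minimal sketch computing it directly from the weight column (assuming .df exposes the fitted weights, as the earlier .df output suggests):
import numpy as np

# Kish's design effect: n * sum(w^2) / (sum(w))^2 (an assumption about balance's definition).
w = adjusted_ipw.df["weight"].to_numpy()
print(len(w) * np.sum(w**2) / np.sum(w) ** 2)  # expected ~1.897, matching the value above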
Outcome analysis¶
Finally, let's see how the adjustment changes the estimated mean of the outcome (happiness):
print(adjusted_ipw.outcomes().summary())
adjusted_ipw.outcomes().plot()
1 outcomes: ['happiness']
Mean outcomes (with 95% confidence intervals):
source       self  target  unadjusted           self_ci         target_ci     unadjusted_ci
happiness  53.389  56.278      48.559  (52.183, 54.595)  (55.961, 56.595)  (47.669, 49.449)

Response rates (relative to number of respondents in sample):
   happiness
n     1000.0
%      100.0
Response rates (relative to notnull rows in the target):
   happiness
n     1000.0
%       10.0
Response rates (in the target):
   happiness
n    10000.0
%      100.0
The estimated mean happiness according to our sample is around 48.6 without any adjustment and 53.4 with the ipw adjustment (the cbps estimate, below, is 54.4), while the target mean is 56.3. The following shows the distribution of happiness:
print(adjusted_cbps.outcomes().summary())
adjusted_cbps.outcomes().plot()
1 outcomes: ['happiness']
Mean outcomes (with 95% confidence intervals):
source       self  target  unadjusted          self_ci         target_ci     unadjusted_ci
happiness  54.389  56.278      48.559  (53.02, 55.757)  (55.961, 56.595)  (47.669, 49.449)

Response rates (relative to number of respondents in sample):
   happiness
n     1000.0
%      100.0
Response rates (relative to notnull rows in the target):
   happiness
n     1000.0
%       10.0
Response rates (in the target):
   happiness
n    10000.0
%      100.0
As we can see, CBPS has a larger design effect, but it also removes more of the ASMD and moves the outcome estimate closer to the target. So there are pros and cons to each of the two methods.
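One way to think about this trade-off is the effective sample size implied by Kish's design effect, n_eff ≈ n / deff. A quick sketch using the design effects computed above:
# Effective sample size implied by the design effect: n_eff ~= n / deff.
n = 1000  # number of observations in the sample
for name, adj in [("ipw", adjusted_ipw), ("cbps", adjusted_cbps)]:
    deff = adj.weights().design_effect()
    print(name, round(n / deff, 1))  # ~527.0 for ipw, ~359.5 for cbps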
Downloading data¶
Finally, we can prepare the data to be downloaded for future analyses.
adjusted_cbps.to_download()
# We can prepare the data to be exported as CSV - showing the first 500 characters for simplicity:
adjusted_cbps.to_csv()[0:500]
'id,gender,age_group,income,happiness,weight\n0,Male,25-34,6.428659499046228,26.043028759747298,5.093908021189902\n1,Female,18-24,9.940280228116047,66.88548460632677,0.4137864502389744\n2,Male,18-24,2.6736231547518043,37.091921916683006,2.255219921002779\n3,,18-24,10.550307519418066,49.39405003271002,4.974470708135918\n4,,18-24,2.689993854299385,72.30420755038209,3.343868455923355\n5,,35-44,5.995497722733131,57.28281646341816,17.083435577163435\n6,,18-24,12.63469573898972,31.663293445944596,5.5913639935'
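To write the result straight to disk, to_csv can presumably also be given a file path, mirroring the pandas to_csv API (the file name below is a hypothetical placeholder):
# A sketch, assuming to_csv forwards a path to pandas' DataFrame.to_csv:
adjusted_cbps.to_csv("adjusted_cbps_weights.csv")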
# Session info
import session_info
session_info.show(html=False, dependencies=True)
-----
balance             0.9.1
pandas              1.4.3
session_info        1.0.0
-----
PIL                 11.0.0
anyio               NA
apport_python_hook  NA
argcomplete         NA
arrow               1.3.0
asttokens           NA
attr                24.2.0
attrs               24.2.0
babel               2.16.0
beta_ufunc          NA
binom_ufunc         NA
certifi             2020.06.20
chardet             4.0.0
charset_normalizer  3.4.0
colorama            0.4.4
comm                0.2.2
coxnet              NA
cvcompute           NA
cvelnet             NA
cvfishnet           NA
cvglmnet            NA
cvglmnetCoef        NA
cvglmnetPredict     NA
cvlognet            NA
cvmrelnet           NA
cvmultnet           NA
cycler              0.12.1
cython_runtime      NA
dateutil            2.9.0.post0
debugpy             1.8.8
decorator           5.1.1
defusedxml          0.7.1
elnet               NA
exceptiongroup      1.2.2
executing           2.1.0
fastjsonschema      NA
fishnet             NA
fqdn                NA
gi                  3.42.1
gio                 NA
glib                NA
glmnet              NA
glmnetCoef          NA
glmnetControl       NA
glmnetPredict       NA
glmnetSet           NA
glmnet_python       NA
gobject             NA
gtk                 NA
hypergeom_ufunc     NA
idna                3.3
ipfn                NA
ipykernel           6.29.5
isoduration         NA
jedi                0.19.2
jinja2              3.1.4
joblib              1.4.2
json5               0.9.28
jsonpointer         2.0
jsonschema          4.23.0
jsonschema_specifications  NA
jupyter_events      0.10.0
jupyter_server      2.14.2
jupyterlab_server   2.27.3
kiwisolver          1.4.7
loadGlmLib          NA
lognet              NA
markupsafe          2.0.1
matplotlib          3.9.2
matplotlib_inline   0.1.7
mpl_toolkits        NA
mrelnet             NA
nbformat            5.10.4
nbinom_ufunc        NA
ncf_ufunc           NA
numpy               1.24.4
overrides           NA
packaging           24.2
parso               0.8.4
patsy               1.0.1
platformdirs        4.3.6
plotly              5.24.1
prometheus_client   NA
prompt_toolkit      3.0.48
psutil              6.1.0
pure_eval           0.2.3
pydev_ipython       NA
pydevconsole        NA
pydevd              3.2.2
pydevd_file_utils   NA
pydevd_plugins      NA
pydevd_tracing      NA
pygments            2.18.0
pyparsing           2.4.7
pythonjsonlogger    NA
pytz                2022.1
referencing         NA
requests            2.32.3
rfc3339_validator   0.1.4
rfc3986_validator   0.1.1
rpds                NA
scipy               1.9.1
seaborn             0.13.0
send2trash          NA
sitecustomize       NA
six                 1.16.0
sklearn             1.5.2
sniffio             1.3.1
sphinxcontrib       NA
stack_data          0.6.3
statsmodels         0.14.4
tenacity            NA
threadpoolctl       3.5.0
tornado             6.4.1
traitlets           5.14.3
typing_extensions   NA
uri_template        NA
urllib3             1.26.5
wcwidth             0.2.13
webcolors           NA
websocket           1.8.0
wtmean              NA
yaml                5.4.1
zmq                 26.2.0
zoneinfo            NA
zope                NA
-----
IPython             8.29.0
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.6
notebook            7.2.2
-----
Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
Linux-6.5.0-1025-azure-x86_64-with-glibc2.35
-----
Session information updated at 2024-11-21 05:26