This repository contains code to replicate experiments in the 2022 NeurIPS paper, “Learning Concept Credible Models for Mitigating Shortcuts”.
Given access to a representation based on domain knowledge (i.e., known concepts), we want to learn a model that is accurate regardless of whether the training data is biased (i.e., contains shortcuts that do not hold in practice) and regardless of whether the known concepts alone are sufficient for accurate predictions. We call such a model a concept credible model (CCM). To that end, we propose two methods, CCM EYE and CCM RES, that are provably concept credible in some linear settings and that empirically mitigate shortcut learning even when those assumptions are violated.
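To make the linear setting concrete, here is a minimal sketch of two-stage residual fitting. It assumes that "RES" in CCM RES refers to fitting a residual on top of a concept-only predictor, which is suggested by scripts/ccm_r.py taking a trained CBM through --c_model_path. It is not the repository's implementation: the toy data, the helper names fit_scalar and mse, and the single-concept, single-feature setup are all hypothetical and use only the Python standard library.

```python
import random

def fit_scalar(x, y):
    """Least-squares slope for y ~ w * x (no intercept)."""
    return sum(a * b for a, b in zip(x, y)) / sum(a * a for a in x)

def mse(pred, y):
    """Mean squared error between predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(pred, y)) / len(y)

# Hypothetical toy data: one known concept c and one extra feature u.
rng = random.Random(0)
n = 200
c = [rng.gauss(0, 1) for _ in range(n)]
u = [rng.gauss(0, 1) for _ in range(n)]
y = [2.0 * ci + 0.5 * ui + rng.gauss(0, 0.1) for ci, ui in zip(c, u)]

# Stage 1: fit the label from the known concept only.
w_c = fit_scalar(c, y)
residual = [yi - w_c * ci for ci, yi in zip(c, y)]

# Stage 2: fit only what the concept could not explain from the extra feature.
w_u = fit_scalar(u, residual)

pred_stage1 = [w_c * ci for ci in c]
pred_final = [w_c * ci + w_u * ui for ci, ui in zip(c, u)]

print("concept weight:", w_c)
print("residual weight:", w_u)
print("stage-1 MSE:", mse(pred_stage1, y))
print("final MSE:", mse(pred_final, y))
```

Under this assumption, the concept gets the first chance to explain the label, so the extra feature can only pick up what is left over. By construction, the stage-1 residual is orthogonal to the concept, and the second stage cannot increase the training error.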
The code directories are organized as follows:
mimic_scripts/
contains the training code for reproducing experiments on the MIMIC-CXR dataset.
scripts/
contains the training code for reproducing experiments on the CUB birds dataset.
notebooks/
contains IPython notebooks for visualizing the results.
Dependencies are listed in Pipfile and can be installed with pipenv (run pipenv install in the directory containing the Pipfile).
To run the baseline and CCM models on the CUB dataset:
Getting the concept model C
python scripts/concept_model.py --transform flip --lr_step 1000 -t 0 -s noise --n_shortcuts 10
Oracle CBM used to generate the shortcut
python scripts/cbm.py --lr_step 15 -s noise -t 1 --n_shortcuts 10 --c_model_path <path to C>/concepts
For CBM
python scripts/cbm.py --lr_step 15 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --c_model_path <path to C>/concepts
For STD(X)
python scripts/standard_model.py -s <path to oracle CBM>/cbm.pt --n_shortcuts 10 -t 1
For STD(C, X)
python scripts/ccm.py --lr_step 15 --alpha 0 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to C>/concepts
For CCM RES
python scripts/ccm_r.py --lr_step 15 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to CBM>/cbm
For CCM EYE
python scripts/ccm.py --lr_step 15 --alpha 0.001 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to C>/concepts
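The commands above depend on each other: every model after the oracle CBM takes the oracle's cbm.pt through -s, and the last three also take a trained STD(X) model through --u_model_path. The sketch below assembles the downstream commands from four paths so that each path is typed only once. The function build_cub_commands and its parameter names are hypothetical helpers, not part of this repository. It uses only the scripts and flags listed above, and it prints the commands rather than running them.

```python
import shlex

def build_cub_commands(concept_path, oracle_cbm_path, cbm_path, std_path):
    """Return the CUB commands that follow the oracle CBM, as argv lists.

    All four arguments are paths to saved models that you supply yourself;
    nothing here guesses where the training scripts write their outputs.
    """
    # Flags shared by every downstream run (flag order is not significant
    # for typical argument parsers, so it may differ from the text above).
    shared = ["-s", oracle_cbm_path, "-t", "1", "--n_shortcuts", "10"]
    return {
        "CBM": ["python", "scripts/cbm.py", "--lr_step", "15", *shared,
                "--c_model_path", concept_path],
        "STD(X)": ["python", "scripts/standard_model.py", *shared],
        "STD(C, X)": ["python", "scripts/ccm.py", "--lr_step", "15",
                      "--alpha", "0", *shared,
                      "--u_model_path", std_path,
                      "--c_model_path", concept_path],
        "CCM RES": ["python", "scripts/ccm_r.py", "--lr_step", "15", *shared,
                    "--u_model_path", std_path,
                    "--c_model_path", cbm_path],
        "CCM EYE": ["python", "scripts/ccm.py", "--lr_step", "15",
                    "--alpha", "0.001", *shared,
                    "--u_model_path", std_path,
                    "--c_model_path", concept_path],
    }

if __name__ == "__main__":
    # Placeholder paths, exactly as written in the commands above.
    commands = build_cub_commands(
        concept_path="<path to C>/concepts",
        oracle_cbm_path="<path to oracle CBM>/cbm.pt",
        cbm_path="<path to CBM>/cbm",
        std_path="<path to STD(X)>/standard",
    )
    for name, argv in commands.items():
        print(f"# {name}")
        print(shlex.join(argv))
```

In the commands as listed, STD(C, X) and CCM EYE differ only in --alpha (0 versus 0.001), while CCM RES uses a separate script and takes the trained CBM, not the concept model, as --c_model_path.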
Logging
I log every command that is run. To view the log:
track log
To see how to use my command tracking:
track -h