Commit e3fc4342 authored by Shengpu Tang (tangsp)

experiments: data extraction

parent f4a1c687
%% Cell type:code id: tags:
``` python
import numpy as np
import pandas as pd
import os, sys, time
from datetime import datetime, timedelta
import pickle
from collections import Counter
```
%% Cell type:code id: tags:
``` python
import yaml
config = yaml.safe_load(open('../config.yaml'))
data_path = config['data_path']
mimic3_path = config['mimic3_path']
import pathlib
pathlib.Path(data_path, 'population').mkdir(parents=True, exist_ok=True)
```
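%% Cell type:markdown id: tags:
For reference, a sketch of what `../config.yaml` presumably contains, based on the keys read in this pipeline (`data_path`, `mimic3_path`, and the `column_names` used later by `prepare_input.py`); the paths here are placeholders, not values from the repo:
```yaml
# Sketch only -- actual values are installation-specific.
data_path: ../data/
mimic3_path: /path/to/mimic-iii/
column_names:
    ID: ID
    t: t
    var_name: variable_name
    var_value: variable_value
```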
%% Cell type:code id: tags:
``` python
patients = pd.read_csv(mimic3_path + 'PATIENTS.csv', parse_dates=['DOB', 'DOD'], usecols=['SUBJECT_ID', 'DOB', 'DOD'])
admissions = pd.read_csv(mimic3_path + 'ADMISSIONS.csv', parse_dates=['DEATHTIME'], usecols=['SUBJECT_ID', 'HADM_ID', 'DEATHTIME', 'HOSPITAL_EXPIRE_FLAG'])
examples = pd.read_csv(data_path + 'prep/icustays_MV.csv', parse_dates=['INTIME', 'OUTTIME']).sort_values(by='ICUSTAY_ID') # Only Metavision
examples = pd.merge(examples, patients, on='SUBJECT_ID', how='left')
examples = pd.merge(examples, admissions, on=['SUBJECT_ID', 'HADM_ID'], how='left')
examples['AGE'] = examples.apply(lambda x: (x['INTIME'] - x['DOB']).total_seconds(), axis=1) / 3600 / 24 / 365.25
examples['LOS'] = examples['LOS'] * 24 # Convert to hours
```
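%% Cell type:markdown id: tags:
Note: in MIMIC-III, dates of birth for patients older than 89 are shifted, so their computed `AGE` comes out above 300 years. The cell below is a hypothetical adjustment (not part of the original notebook) that maps these shifted ages to 91.4, the median true age documented for this group.
%% Cell type:code id: tags:
``` python
# Hypothetical cleanup (not in the original pipeline): MIMIC-III shifts DOB
# for patients aged >89, so their computed AGE exceeds 300 years.
# 91.4 is the median age reported for this group in the MIMIC-III documentation.
examples.loc[examples['AGE'] > 100, 'AGE'] = 91.4
```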
%% Cell type:code id: tags:
``` python
tasks = ['ARF', 'Shock']
label_defs = { task: pd.read_csv(data_path + 'labels/{}.csv'.format(task)) for task in tasks }
```
%% Cell type:code id: tags:
``` python
# Size of the source population; the boolean masks saved below are indexed against this
N = examples['ICUSTAY_ID'].nunique()
print('Source population', N)
```
%%%% Output: stream
Source population 23620
%% Cell type:code id: tags:
``` python
assert (examples['INTIME'] <= examples['OUTTIME']).all()
assert (examples['DBSOURCE'] == 'metavision').all()
```
%% Cell type:code id: tags:
``` python
# Remove non-adults
min_age = 18
max_age = np.inf # no max age
examples = examples[(examples.AGE >= min_age) & (examples.AGE <= max_age)]
print('Exclude non-adults', examples['ICUSTAY_ID'].nunique())
examples_ = examples  # reference cohort, reused for each prediction horizon below
```
%%%% Output: stream
Exclude non-adults 23593
%% Cell type:code id: tags:
``` python
for T in [4, 12]:
    print('======')
    print('prediction time', T, 'hour')

    # Remove stays where the patient died before the cutoff hour
    examples = examples_[(examples_.DEATHTIME >= examples_.INTIME + timedelta(hours=T)) | (examples_.DEATHTIME.isnull())]
    print('Exclude deaths', examples['ICUSTAY_ID'].nunique())

    # Remove stays with LOS < cutoff hour
    examples = examples[examples['LOS'] >= T]
    print('Exclude discharges', examples['ICUSTAY_ID'].nunique())

    populations = {}
    # Remove stays where the event onset occurs before the cutoff
    for task in tasks:
        print('---')
        print('Outcome', task)
        label_def = label_defs[task]

        # The reset_index/set_index dance preserves the original row index through the merge
        pop = examples[['ICUSTAY_ID']].reset_index() \
            .merge(label_def[['ICUSTAY_ID', '{}_ONSET_HOUR'.format(task)]], on='ICUSTAY_ID', how='left') \
            .set_index('index').copy()
        pop = pop[(pop['{}_ONSET_HOUR'.format(task)] >= T) | pop['{}_ONSET_HOUR'.format(task)].isnull()]
        pop['{}_LABEL'.format(task)] = pop['{}_ONSET_HOUR'.format(task)].notnull().astype(int)
        pop.to_csv(data_path + 'population/{}_{}h.csv'.format(task, T), index=False)

        # Construct boolean mask over the source population
        ## NOTE: uses pop.index here, assuming the original row index is preserved
        idx = pop.index
        ## Otherwise, there's a slower version:
        # idx = np.array([examples[examples.ICUSTAY_ID == i].index[0] for i in pop['ICUSTAY_ID']])
        mask_array = np.zeros(N, dtype=bool)
        mask_array[idx] = True

        # Save population boolean mask
        np.save(data_path + 'population/mask_{}_{}h.npy'.format(task, T), mask_array)
        np.savetxt(data_path + 'population/mask_{}_{}h.txt'.format(task, T), mask_array, fmt='%i')

        populations[task] = pop
        print('Exclude onset', len(pop))
```
%%%% Output: stream
======
prediction time 4 hour
Exclude deaths 23499
Exclude discharges 23401
---
Outcome ARF
Exclude onset 15873
---
Outcome Shock
Exclude onset 19342
======
prediction time 12 hour
Exclude deaths 23319
Exclude discharges 23060
---
Outcome ARF
Exclude onset 14174
---
Outcome Shock
Exclude onset 17588
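%% Cell type:markdown id: tags:
As a quick sanity check (a sketch, not in the original notebook), the saved boolean mask should select exactly the stays that survived all three exclusions; e.g., for ARF at 4 hours:
%% Cell type:code id: tags:
``` python
# Sketch: verify a saved mask against the cohort size printed above.
mask = np.load(data_path + 'population/mask_ARF_4h.npy')
assert mask.shape == (N,)
assert mask.sum() == 15873  # 'Exclude onset' count for ARF @ 4h, per the output above
```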
%% Cell type:code id: tags:
``` python
for T in [48]:
    print('======')
    print('prediction time', T, 'hour')

    # Remove stays where the patient died before the cutoff hour
    examples = examples_[(examples_.DEATHTIME >= examples_.INTIME + timedelta(hours=T)) | (examples_.DEATHTIME.isnull())]
    print('Exclude deaths', examples['ICUSTAY_ID'].nunique())

    # Remove stays with LOS < cutoff hour
    examples = examples[examples['LOS'] >= T]
    print('Exclude discharges', examples['ICUSTAY_ID'].nunique())

    # Mortality has no onset time; the label comes directly from HOSPITAL_EXPIRE_FLAG
    for task in ['mortality']:
        print('---')
        print('Outcome', task)
        examples['{}_LABEL'.format(task)] = examples.HOSPITAL_EXPIRE_FLAG
        pop = examples[['ICUSTAY_ID', '{}_LABEL'.format(task)]]
        pop.to_csv(data_path + 'population/{}_{}h.csv'.format(task, T), index=False)
        print('Exclude onset', len(pop))
```
%%%% Output: stream
======
prediction time 48 hour
Exclude deaths 22776
Exclude discharges 11695
---
Outcome mortality
Exclude onset 11695
%% Cell type:code id: tags:
``` python
```
import os, yaml
with open(os.path.join(os.path.dirname(__file__), '../config.yaml')) as f:
config = yaml.full_load(f)
data_path = os.path.join(os.path.dirname(__file__), config['data_path'])
mimic3_path = os.path.join(os.path.dirname(__file__), config['mimic3_path'])
parallel = True
n_jobs = 72
"""
generate_labels.py
Author: Shengpu Tang
Generate labels for two adverse outcomes: ARF and shock.
"""
import pandas as pd
import numpy as np
import scipy.stats
import itertools
from collections import OrderedDict, Counter
from joblib import Parallel, delayed
from tqdm import tqdm as tqdm
import yaml
data_path = yaml.full_load(open('../config.yaml'))['data_path']
import pathlib
pathlib.Path(data_path, 'labels').mkdir(parents=True, exist_ok=True)
examples = pd.read_csv(data_path + 'prep/icustays_MV.csv', parse_dates=['INTIME', 'OUTTIME']).sort_values(by='ICUSTAY_ID')
chartevents = pd.read_pickle(data_path + 'prep/chartevents.p')
procedures = pd.read_pickle(data_path + 'prep/procedureevents_mv.p')
inputevents = pd.read_pickle(data_path + 'prep/inputevents_mv.p')
ventilation = [
'225792', # Invasive Ventilation
'225794', # Non-invasive Ventilation
]
PEEP = [
'220339', # PEEP set
]
vasopressors = [
'221906', # Norepinephrine
'221289', # Epinephrine
'221662', # Dopamine
'222315', # Vasopressin
'221749', # Phenylephrine
]
## ARF: (PEEP) OR (mechanical ventilation)
df_PEEP = chartevents[chartevents.ITEMID.isin(PEEP)].copy()
df_vent = procedures[procedures.ITEMID.isin(ventilation)].rename(columns={'t_start': 't'}).copy()
df_ARF = pd.concat([df_PEEP[['ICUSTAY_ID', 't']], df_vent[['ICUSTAY_ID', 't']]], axis=0)
df_ARF['ICUSTAY_ID'] = df_ARF['ICUSTAY_ID'].astype(int)
df_ARF = df_ARF.sort_values(by=['ICUSTAY_ID', 't']).drop_duplicates(['ICUSTAY_ID'], keep='first').reset_index(drop=True)
df_ARF = df_ARF.rename(columns={'t': 'ARF_ONSET_HOUR'})
df_ARF = pd.merge(examples[['ICUSTAY_ID']], df_ARF, on='ICUSTAY_ID', how='left')
df_ARF['ARF_LABEL'] = df_ARF['ARF_ONSET_HOUR'].notnull().astype(int)
print('ARF: ', dict(Counter(df_ARF['ARF_LABEL'])), 'N = {}'.format(len(df_ARF)), sep='\t')
df_ARF.to_csv(data_path + 'labels/ARF.csv', index=False)
## Shock: (one of vasopressors)
df_vaso = inputevents[inputevents.ITEMID.isin(vasopressors)].rename(columns={'t_start': 't'}).copy()
df_shock = df_vaso.copy()
df_shock['ICUSTAY_ID'] = df_shock['ICUSTAY_ID'].astype(int)
df_shock = df_shock.sort_values(by=['ICUSTAY_ID', 't']).drop_duplicates(['ICUSTAY_ID'], keep='first').reset_index(drop=True)
df_shock = df_shock.rename(columns={'t': 'Shock_ONSET_HOUR'})
df_shock = pd.merge(examples[['ICUSTAY_ID']], df_shock, on='ICUSTAY_ID', how='left')
df_shock['Shock_LABEL'] = df_shock['Shock_ONSET_HOUR'].notnull().astype(int)
print('Shock: ', dict(Counter(df_shock['Shock_LABEL'])), 'N = {}'.format(len(df_shock)), sep='\t')
df_shock.to_csv(data_path + 'labels/Shock.csv', index=False)
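# --- Sketch (not in the original script): re-read a saved label file and
# confirm the LABEL column flags exactly the stays with a recorded onset hour.
df_check = pd.read_csv(data_path + 'labels/ARF.csv')
assert (df_check['ARF_LABEL'] == df_check['ARF_ONSET_HOUR'].notnull().astype(int)).all()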
HR:
- 220045 # Heart Rate
SysBP:
- 224167 # Manual Blood Pressure Systolic Left
- 227243 # Manual Blood Pressure Systolic Right
- 220050 # Arterial Blood Pressure systolic
- 220179 # Non Invasive Blood Pressure systolic
- 225309 # ART BP Systolic
DiaBP:
- 224643 # Manual Blood Pressure Diastolic Left
- 227242 # Manual Blood Pressure Diastolic Right
- 220051 # Arterial Blood Pressure diastolic
- 220180 # Non Invasive Blood Pressure diastolic
- 225310 # ART BP Diastolic
RR:
- 220210 # Respiratory Rate
- 224690 # Respiratory Rate (Total)
Temperature:
- 223761 # Temperature Fahrenheit
- 223762 # Temperature Celsius
SpO2:
- 220277 # O2 saturation pulseoxymetry
Height:
- 226707 # Height
- 226730 # Height (cm)
Weight:
- 224639 # Daily Weight
- 226512 # Admission Weight (Kg)
- 226531 # Admission Weight (lbs.)
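This ITEMID grouping maps each clinical concept to its Metavision ITEMIDs. A minimal sketch of how such a map can be applied, assuming it is saved as `variables.yaml` (a hypothetical filename) and that the chartevents pickle from the prep step has an `ITEMID` column:
``` python
import yaml
import pandas as pd

# Sketch only: 'variables.yaml' is an assumed filename for the map above.
data_path = yaml.safe_load(open('../config.yaml'))['data_path']
var_map = yaml.safe_load(open('variables.yaml'))
chartevents = pd.read_pickle(data_path + 'prep/chartevents.p')
# ITEMID dtype varies between int and str across extractions; normalize to str.
hr_rows = chartevents[chartevents['ITEMID'].astype(str).isin(map(str, var_map['HR']))]
```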
%% Cell type:markdown id: tags:
Comparison with the multitask benchmark cohort (Harutyunyan et al., 2019): https://www.nature.com/articles/s41597-019-0103-9
%% Cell type:code id: tags:
``` python
import yaml
with open('../config.yaml') as f:
config = yaml.full_load(f)
data_path = config['data_path']
```
%% Cell type:code id: tags:
``` python
import pandas as pd
from collections import defaultdict
```
%% Cell type:code id: tags:
``` python
df_train = pd.read_csv('train_listfile.csv')
df_val = pd.read_csv('val_listfile.csv')
df_test = pd.read_csv('test_listfile.csv')
```
%% Cell type:code id: tags:
``` python
df_train['y_true'].mean()
```
%%%% Output: execute_result
0.13534500374633882
%% Cell type:markdown id: tags:
## Removing non-metavision ICU stays
%% Cell type:code id: tags:
``` python
icustays = pd.read_csv('all_stays.csv')
icustays = icustays.sort_values(by=['SUBJECT_ID', 'INTIME', 'OUTTIME']).reset_index(drop=True)
metavision = icustays[icustays['DBSOURCE'] == 'metavision']['ICUSTAY_ID']
```
%% Cell type:code id: tags:
``` python
stays_by_subjects = defaultdict(list)
for i, (j, k) in icustays[['SUBJECT_ID', 'ICUSTAY_ID']].iterrows():
    stays_by_subjects[j].append(k)
```
%% Cell type:code id: tags:
``` python
my_labels = pd.read_csv('../' + data_path + 'population/mortality_48h.csv').set_index('ICUSTAY_ID')
```
%% Cell type:code id: tags:
``` python
df_out = []
for part, df_part in zip(['train', 'val', 'test'], [df_train, df_val, df_test]):
    for i, (name, y) in df_part.iterrows():
        try:
            ID, ep, _ = name.split('_')
            ID = int(ID)
            ep = int(ep[7:]) - 1  # 'episodeX' -> 0-based index into this subject's stays
            stay_ID = stays_by_subjects[ID][ep]
            if stay_ID in metavision.values and stay_ID in my_labels.index:
                # Only keep stays recorded using metavision whose patients have not died by 48 hours
                my_y = my_labels.loc[stay_ID, 'mortality_LABEL']  # label from our cohort (kept for reference)
                df_out.append((stay_ID, name, part, y, y))
            else:
                continue
        except:
            print(name, part, y)
```
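%% Cell type:markdown id: tags:
For reference, the benchmark's listfiles name each example `<SUBJECT_ID>_episode<k>_timeseries.csv`; a small illustration of the parsing above (the file name here is made up):
%% Cell type:code id: tags:
``` python
# Illustration of the name-parsing logic above, using a made-up file name.
name = '10011_episode1_timeseries.csv'
ID, ep, _ = name.split('_')
assert int(ID) == 10011
assert int(ep[7:]) - 1 == 0  # 'episode1' -> 0-based index into stays_by_subjects[ID]
```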
%% Cell type:code id: tags:
``` python
df_out = pd.DataFrame(df_out, columns=['ID', 'stay', 'partition', 'mortality_LABEL', 'y_true'])
```
%% Cell type:code id: tags:
``` python
df_out = df_out.sort_values(by='ID')
```
%% Cell type:code id: tags:
``` python
df_out.to_csv('../' + data_path + 'population/pop.mortality_benchmark.csv', index=False)
```
%% Cell type:code id: tags:
``` python
df_out['mortality_LABEL'].mean()
```
%%%% Output: execute_result
0.12020519995336365
%% Cell type:code id: tags:
``` python
```
"""
Usage: python prepare_input.py --outcome=<task> --T=<hours> --dt=<hours>
"""
import os, yaml
with open(os.path.join(os.path.dirname(__file__), '../config.yaml')) as f:
config = yaml.full_load(f)
data_path = os.path.join(os.path.dirname(__file__), config['data_path'])
parallel = True
ID_col = config['column_names']['ID']
t_col = config['column_names']['t']
var_col = config['column_names']['var_name']
val_col = config['column_names']['var_value']
import argparse
import pickle
import pandas as pd
import numpy as np
from tqdm import tqdm
from joblib import Parallel, delayed
def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--outcome', type=str, required=True)
    parser.add_argument('--T', type=int, required=True)
    parser.add_argument('--dt', type=float, required=True)
    args = parser.parse_args()

    task = args.outcome
    T = args.T
    dt = args.dt
    print('Preparing pipeline input for: outcome={}, T={}, dt={}'.format(task, T, dt))

    import pathlib
    pathlib.Path(data_path, 'features', 'outcome={}.T={}.dt={}'.format(task, T, dt)) \
        .mkdir(parents=True, exist_ok=True)

    # Load in study population; keep only the index (ICUSTAY_ID, renamed to ID)
    population = pd.read_csv(data_path + 'population/{}_{}h.csv'.format(task, T)) \
        .rename(columns={'ICUSTAY_ID': 'ID'}).set_index('ID')[[]]

    # Load in raw data (from prepare.py)
    with open(data_path + 'formatted/all_data.stacked.p', 'rb') as f:
        data = pickle.load(f)
    # with open(data_path + 'formatted/reformatted_data.T={}.dt={}.p'.format(T, 0.5), 'rb') as f:
    #     data = pickle.load(f)

    ####### TODO: Refactor: resample continuous, resolve duplicates (discrete & continuous)
    # data = resolve_duplicates_discrete_old(data)
    # data = resample_continuous_events_old(data, T, dt)  # includes filtering by time window
    data = resolve_duplicates_discrete(data)
    data = filter_prediction_time(data, T)
    data = resample_continuous_events(data, T, dt)
    data = resolve_duplicates_continuous(data)

    # Combine all DataFrames into one
    df_data = pd.concat(data, axis='index', ignore_index=True)
    df_data = df_data.sort_values(by=[ID_col, t_col, var_col, val_col], na_position='first')

    # Filter by IDs in study population
    df_data = population.join(df_data.set_index('ID')).reset_index()
    assert set(df_data['ID'].unique()) == set(population.index)

    # Save
    df_data.to_pickle(data_path + 'features/outcome={}.T={}.dt={}/input_data.p'.format(task, T, dt))
################################
#### Helper functions ####
################################
def print_header(*content, char='='):
print()
print(char * 80)
print(*content)
print(char * 80, flush=True)
def filter_prediction_time(data_in, T):
    """
    Filter each table in `data_in` by:
        - removing records outside of the prediction window [0, T) hours

    `data_in` is a dict {
        TABLE_NAME: pd.DataFrame object,
    }
    """
    print_header('Filter by prediction time T={}'.format(T), char='-')
    filtered_data = {}
    for table_name in tqdm(sorted(data_in)):
        df = data_in[table_name]
        t_cols = df.columns.intersection(['t', 't_start', 't_end']).tolist()

        # Focus on the prediction window of [0, T)
        if len(t_cols) == 1:  # point events
            if all(pd.isnull(df['t'])):
                pass  # time-invariant table; keep all records
            else:
                df = df[(0 <= df['t']) & (df['t'] < T)].copy()
        elif len(t_cols) == 2:  # range events
            df = df[(0 <= df['t_end']) & (df['t_start'] < T)].copy()

        filtered_data[table_name] = df
    print('Done!')
    return filtered_data
def resample_continuous_events(data, T, dt):
    print_header('Resample continuous events, T={}, dt={}'.format(T, dt), char='-')
    for fname, df in sorted(data.items(), reverse=True):
        t_cols = df.columns.intersection(['t', 't_start', 't_end']).tolist()
        if len(t_cols) == 1:  # point time
            continue
        else:  # ranged time
            assert len(t_cols) == 2
            print(fname)
            df_out = []
            for index, row in tqdm(df.iterrows(), total=df.shape[0]):
                t_start, t_end = row.t_start, row.t_end
                t_range = dt/2 + np.arange(max(0, (t_start//dt)*dt), min(T, (t_end//dt+1)*dt), dt)
                if len(t_range) == 0:
                    continue
                df_tmp = pd.concat(len(t_range) * [row], axis=1).T.drop(columns=['t_start', 't_end'])
                df_tmp['t'] = t_range
                df_out.append(df_tmp)
            df_out = pd.concat(df_out)[['ID', 't', 'variable_name', 'variable_value']]
            data[fname] = df_out
    return data
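# Worked example of the bin arithmetic above (illustrative, not in the script):
# with dt=1 and T=48, a range event spanning t_start=1.2 to t_end=3.7 yields
#   dt/2 + np.arange(max(0, 1.0), min(48, 4.0), 1)  ->  [1.5, 2.5, 3.5]
# i.e., the event is copied into each hourly bin it overlaps, stamped at bin centers.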
def resolve_duplicates_discrete(data):
    """
    Assume input format:
    ------------------------------------------------------------------
    |  ID  |  t (or t_start + t_end)  |  variable_name  |  variable_value  |
    ------------------------------------------------------------------
    """
    print_header('Resolve duplicated event records (discrete)', char='-')

    ### Chart events - duplicate rows
    print('*** CHARTEVENTS')
    print('    getting dups and ~dups')
    df = data['CHARTEVENTS']
    m_dups = df.duplicated(subset=['ID', 't', 'variable_name'], keep=False)
    dups = df[m_dups]
    dup_variables = dups['variable_name'].unique()