Skip to content
Snippets Groups Projects
Commit fd3064e9 authored by Shengpu Tang (tangsp)'s avatar Shengpu Tang (tangsp)
Browse files

Add all files

parent fe6333fb
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
import scipy.stats
import pickle, os, time
import itertools
from datetime import datetime, timedelta
from collections import Counter, defaultdict, namedtuple
from PIL import Image
import yaml
from tqdm import tqdm
import seaborn as sns
from matplotlib import pyplot as plt
```
%% Cell type:code id: tags:
``` python
data_dir = './data/'
```
%% Cell type:code id: tags:
``` python
# Load the population, labels, and baseline features
# Everything is left-joined onto the population table, so each frame has
# exactly one row per BMT_ID in `pop`, in the same order.
pop = pd.read_csv(data_dir + 'population/d10_with_vitals.csv').set_index('BMT_ID')
df_label_full = pop.join(pd.read_csv(data_dir + 'prep/label.csv', index_col='BMT_ID'), how='left')
df_static = pop.join(pd.read_csv(data_dir + 'features/static.csv', index_col='BMT_ID'), how='left')
df_static.index.rename('id', inplace=True)

# Two binary outcomes: the precomputed 'Label_GVHD' column, and an indicator
# for a maximum GVHD grade of 3 or higher.
df_label = df_label_full['Label_GVHD']
df_label34 = (df_label_full['GVHD_max_grade'] >= 3).astype(int)

# Every patient must have complete baseline features. Check the null mask
# directly: testing the truthiness of the values in rows that contain a null
# would let a row made of NaN and zeros slip through.
assert not df_static.isnull().any().any()

# Load the vital sign time series
# NOTE(review): pickle can execute arbitrary code while loading; only use
# this with a trusted, locally produced file.
with open(data_dir + 'features/ts_vitals_by_bmt_2014_2017_MiChart.p', 'rb') as f:
    ts_vitals_by_bmt = pickle.load(f)
# Keep only the patients in the study population (raises KeyError if one is missing).
ts_vitals_by_bmt = {ID: ts_vitals_by_bmt[ID] for ID in pop.index}
```
%% Cell type:code id: tags:
``` python
# Report the cohort size followed by the positive rate of both outcomes.
print('Population size:', len(ts_vitals_by_bmt))
# A single print call with a newline separator; the leading empty string
# produces the blank line between the two sections.
print(
    '',
    'Class balance',
    '{{0,1}} vs. {{2,3,4}}: \t{:.1%}'.format(df_label.mean()),
    '{{0,1,2}} vs. {{3,4}}: \t{:.1%}'.format(df_label34.mean()),
    sep='\n',
)
```
%% Output
Population size: 324
Class balance
{0,1} vs. {2,3,4}: 31.8%
{0,1,2} vs. {3,4}: 13.6%
%% Cell type:code id: tags:
``` python
# Extract vital sign features
# Vital-sign column names; get_trend_features renames each one with a
# '_dt=<dt>' suffix before summarizing it.
variables = ['HR', 'RR', 'SysBP', 'DiaBP', 'Temp', 'SpO2']
# Observation window: get_trend_features keeps only rows with t0 <= t < T.
t0, T = 0, 10
# Width of each summary window (a daily window, per the comments in
# get_trend_features).
dt = 1
import tsfresh
def get_trend_features(t0, T, dt):
    """Summarize each patient's vitals per window, then extract trend features.

    The summary is computed in two stages:

    1. For every patient, rows with ``t0 <= t < T`` are grouped into
       left-closed windows of width ``dt`` and every column is aggregated
       with mean / std / min / max.
    2. The per-window summaries are reshaped to long format and passed to
       ``tsfresh.extract_features``, which computes the mean, linear-trend
       slope, sample entropy and the first FFT coefficient (abs and angle)
       of every summary series.

    Reads the module-level ``ts_vitals_by_bmt`` (mapping of patient id to a
    DataFrame with a ``'t'`` column) and ``variables``.

    Returns the feature DataFrame produced by tsfresh.
    """
    # Loop invariants: window edges, column renaming and the statistics to compute
    window_edges = np.arange(t0, T + dt, dt)
    renamed_columns = {name: '{}_dt={}'.format(name, dt) for name in variables}
    summary_stats = ['mean', 'std', 'min', 'max']

    # Stage 1: summary statistics per patient and window
    per_patient = {}
    for patient_id, vitals in ts_vitals_by_bmt.items():
        in_range = (t0 <= vitals['t']) & (vitals['t'] < T)
        windowed = vitals[in_range].set_index('t').copy()
        windowed = windowed.rename(columns=renamed_columns)
        bins = pd.cut(windowed.index, window_edges, right=False)
        summary = windowed.groupby(bins).agg(summary_stats)
        summary.index.rename('t', inplace=True)
        per_patient[patient_id] = summary.reset_index()

    # Combine all patients into one frame indexed by (id, t)
    summaries = pd.concat(per_patient)
    summaries.index.rename('id', level=0, inplace=True)
    summaries = summaries.sort_index()
    summaries = summaries.reset_index(level=0).set_index(['id', 't'])
    # Flatten the (column, statistic) pairs into single '<column>_<statistic>' names
    summaries.columns = ['_'.join(pair).strip() for pair in summaries.columns.values]

    # Long format expected by tsfresh: one row per (id, t, variable)
    long_format = summaries.stack().copy()
    long_format.index.rename('variable', level=-1, inplace=True)
    long_format.rename('value', inplace=True)
    long_format = long_format.reset_index()
    assert not pd.isnull(long_format['value']).any()

    # Stage 2: trend features of every per-window summary series
    feature_params = {
        'mean': None,
        'linear_trend': [{'attr': 'slope'}],
        'sample_entropy': None,
        'fft_coefficient': [
            {'coeff': 1, 'attr': 'abs'},
            {'coeff': 1, 'attr': 'angle'},
        ],
    }
    return tsfresh.extract_features(
        long_format,
        column_id='id',
        column_sort='t',
        column_kind='variable',
        column_value='value',
        default_fc_parameters=feature_params,
    )
extracted_features = get_trend_features(t0, T, dt)
```
%% Output
Feature Extraction: 100%|██████████| 278/278 [00:01<00:00, 263.91it/s]
%% Cell type:code id: tags:
``` python
extracted_features.to_csv('data/ts_features.csv')
```
%% Cell type:code id: tags:
``` python
# Bin values by quintiles
# Each extracted feature is cut into (up to) five quantile bins; bins whose
# edges coincide are dropped. The bins are then one-hot encoded and appended
# to the static features.
# The indicator dtype is pinned to uint8: pandas 2.0 changed the get_dummies
# default from uint8 to bool, and bool columns next to integer-typed static
# columns (as the displayed 0/1 values suggest) would turn `.values` into an
# object array, on which the later np.isnan check fails.
# NOTE(review): bin edges are computed on all patients, including those later
# held out as the test set - confirm this is intended.
df_features = df_static.join(
    pd.get_dummies(
        extracted_features.apply(pd.qcut, q=5, duplicates='drop'),
        prefix_sep='_',
        dtype=np.uint8,
    )
)
print(df_features.shape)
```
%% Output
(324, 652)
%% Cell type:code id: tags:
``` python
df_features.head()
```
%% Output
Age_(-0.001, 18.0] Age_(18.0, 45.0] Age_(45.0, 75.0] \
id
201406001 0 0 1
201406002 0 0 1
201406003 0 0 1
201406004 0 0 1
201406005 0 0 1
Disease Code category_Malignant \
id
201406001 1
201406002 1
201406003 1
201406004 1
201406005 1
Disease Code category_Non-malignant \
id
201406001 0
201406002 0
201406003 0
201406004 0
201406005 0
Disease Risk_0 - Non-malignant Disease Risk_1 - Low \
id
201406001 0 0
201406002 0 0
201406003 0 1
201406004 0 0
201406005 0 0
Disease Risk_2 - Intermediate Disease Risk_3 - High \
id
201406001 1 0
201406002 1 0
201406003 0 0
201406004 0 1
201406005 0 1
Intensity_0 - Full ... \
id ...
201406001 1 ...
201406002 1 ...
201406003 1 ...
201406004 0 ...
201406005 1 ...
Temp_dt=1_std__mean_(0.0977, 0.157] \
id
201406001 0
201406002 0
201406003 0
201406004 0
201406005 0
Temp_dt=1_std__mean_(0.157, 0.194] \
id
201406001 0
201406002 0
201406003 0
201406004 0
201406005 0
Temp_dt=1_std__mean_(0.194, 0.235] \
id
201406001 1
201406002 0
201406003 0
201406004 1
201406005 0
Temp_dt=1_std__mean_(0.235, 0.292] \
id
201406001 0
201406002 1
201406003 1
201406004 0
201406005 1
Temp_dt=1_std__mean_(0.292, 0.564] \
id
201406001 0
201406002 0
201406003 0
201406004 0
201406005 0
Temp_dt=1_std__sample_entropy_(0.67, 1.861] \
id
201406001 0
201406002 0
201406003 0
201406004 1
201406005 0
Temp_dt=1_std__sample_entropy_(1.861, 2.197] \
id
201406001 0
201406002 0
201406003 1
201406004 0
201406005 1
Temp_dt=1_std__sample_entropy_(2.197, 2.42] \
id
201406001 1
201406002 0
201406003 0
201406004 0
201406005 0
Temp_dt=1_std__sample_entropy_(2.42, 2.708] \
id
201406001 0
201406002 1
201406003 0
201406004 0
201406005 0
Temp_dt=1_std__sample_entropy_(2.708, 3.807]
id
201406001 0
201406002 0
201406003 0
201406004 0
201406005 0
[5 rows x 652 columns]
%% Cell type:code id: tags:
``` python
df_features.to_csv('data/df_features.csv')
```
%% Cell type:code id: tags:
``` python
# Plain NumPy views of the feature table and of the binary labels
X, y = df_features.values, df_label.values
# Make sure there are no nan values
assert not np.any(np.isnan(X))
assert not np.any(np.isnan(y))
```
%% Cell type:code id: tags:
``` python
# Labels for the grade >= 3 outcome; like X and y they must not contain NaN
y34 = df_label34.values
assert not np.any(np.isnan(y34))
```
%% Cell type:code id: tags:
``` python
X.shape, y.shape
```
%% Output
((324, 652), (324,))
%% Cell type:code id: tags:
``` python
np.savez('data/Xy.npz', X=X, y=y, y34=y34)
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
import random
import numpy as np
import pandas as pd
from joblib import dump, load
np.random.seed(42)
random.seed(42)
```
%% Cell type:code id: tags:
``` python
data_dir = './data/'
# Read the feature matrix and labels back from the saved archive; the
# context manager closes the .npz file once both arrays are loaded.
with np.load('data/Xy.npz') as f:
    X, y = f['X'], f['y']
```
%% Cell type:code id: tags:
``` python
# Perform temporal split of data into train/test sets
pop = pd.read_csv(data_dir + 'population/d10_with_vitals.csv').set_index('BMT_ID')
# NOTE(review): presumably BMT_ID values sort chronologically (e.g. 201406001),
# so this threshold separates the earlier patients from the later ones - confirm.
split_date = 201701001
# Rows from split_idx onward (the last 85) form the test set
split_idx = -85
# Sanity check: the positional split agrees with the ID threshold
assert (pop.index[:split_idx] < split_date).all()
assert (pop.index[split_idx:] >= split_date).all()
```
%% Cell type:code id: tags:
``` python
from sklearn import preprocessing, model_selection, metrics, utils
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
from joblib import Parallel, delayed
from sklearn.base import clone
```
%% Cell type:code id: tags:
``` python
# Specify hyperparameters and cv parameters
# Base model: class-balanced, L2-penalized logistic regression (liblinear solver)
base_estimator = LogisticRegression(
    solver='liblinear',
    penalty='l2',
    class_weight='balanced',
)
# Grid over C on a log scale: 1e-6, 1e-5, ..., 1e6
param_grid = dict(
    C=[10. ** exponent for exponent in range(-6, 7)],
    penalty=['l2'],
)
```
%% Cell type:markdown id: tags:
## Train model with baseline+vitals
%% Cell type:code id: tags:
``` python
# Rows before split_idx are used for training, the remaining rows for testing
Xtr, Xte = X[:split_idx], X[split_idx:]
ytr, yte = y[:split_idx], y[split_idx:]

# Model selection by repeated stratified k-fold CV on the training rows.
# n_splits / n_repeats are passed by keyword: newer scikit-learn releases
# make them keyword-only, so the positional form raises a TypeError there.
cv_splits, cv_repeat = 5, 20
cv = model_selection.RepeatedStratifiedKFold(
    n_splits=cv_splits, n_repeats=cv_repeat, random_state=0,
)
clf = model_selection.GridSearchCV(
    clone(base_estimator), param_grid,
    cv=cv, scoring='roc_auc', n_jobs=5,
)
clf.fit(Xtr, ytr)

# AUROC of the fitted search object on the held-out rows
test_score = metrics.roc_auc_score(yte, clf.decision_function(Xte))
```
%% Cell type:code id: tags:
``` python
y_true = yte
y_score = clf.decision_function(Xte)
def boostrap_func(i, y_true, y_score):
    """Compute the ROC curve and AUROC of one bootstrap replicate.

    Labels and scores are resampled together, with replacement, using ``i``
    as the random seed so that replicate ``i`` is reproducible.

    Returns a ``(roc_curve_output, auc)`` tuple.
    """
    labels_b, scores_b = utils.resample(y_true, y_score, replace=True, random_state=i)
    curve = metrics.roc_curve(labels_b, scores_b)
    auc = metrics.roc_auc_score(labels_b, scores_b)
    return curve, auc
# Run 1000 bootstrap replicates in parallel (replicate i is seeded with i)
bootstrap_results = Parallel(n_jobs=4)(
    delayed(boostrap_func)(i, y_true, y_score)
    for i in tqdm(range(1000), leave=False)
)
roc_curves, auc_scores = zip(*bootstrap_results)
# Report the bootstrap median with the 2.5th / 97.5th percentiles
print('Test AUC: {:.3f} ({:.3f}, {:.3f})'.format(
    np.median(auc_scores),
    np.percentile(auc_scores, 2.5),
    np.percentile(auc_scores, 97.5),
))
```
%% Output
Test AUC: 0.658 (0.536, 0.784)
Test AUC: 0.659 ± 0.063
%% Cell type:code id: tags:
``` python
dump(clf, 'data/model_combined.joblib')
```
%% Output
['data/model_combined.joblib']
%% Cell type:markdown id: tags:
## Train model with baseline features only
%% Cell type:code id: tags:
``` python
# Same row split as the combined model, restricted to the first 52 columns
# NOTE(review): presumably these are the baseline (static) feature columns,
# which come first in the joined feature table - confirm the count of 52.
Xtr, Xte = X[:split_idx, :52], X[split_idx:, :52]
ytr, yte = y[:split_idx], y[split_idx:]

# Model selection by repeated stratified k-fold CV on the training rows.
# n_splits / n_repeats are passed by keyword: newer scikit-learn releases
# make them keyword-only, so the positional form raises a TypeError there.
cv_splits, cv_repeat = 5, 20
cv = model_selection.RepeatedStratifiedKFold(
    n_splits=cv_splits, n_repeats=cv_repeat, random_state=0,
)
clf = model_selection.GridSearchCV(
    clone(base_estimator), param_grid,
    cv=cv, scoring='roc_auc', n_jobs=5,
)
clf.fit(Xtr, ytr)

# AUROC of the fitted search object on the held-out rows
test_score = metrics.roc_auc_score(yte, clf.decision_function(Xte))
```
%% Output
/data4/tangsp/venv/lib/python3.7/site-packages/sklearn/model_selection/_search.py:814: DeprecationWarning: The default of the `iid` parameter will change from True to False in version 0.22 and will be removed in 0.24. This will change numeric results when test-set sizes are unequal.
DeprecationWarning)
%% Cell type:code id: tags:
``` python
y_true = yte
y_score = clf.decision_function(Xte)
def boostrap_func(i, y_true, y_score):
    """Compute the ROC curve and AUROC of one bootstrap replicate.

    Labels and scores are resampled together, with replacement, using ``i``
    as the random seed so that replicate ``i`` is reproducible.

    Returns a ``(roc_curve_output, auc)`` tuple.
    """
    labels_b, scores_b = utils.resample(y_true, y_score, replace=True, random_state=i)
    curve = metrics.roc_curve(labels_b, scores_b)
    auc = metrics.roc_auc_score(labels_b, scores_b)
    return curve, auc
# Run 1000 bootstrap replicates in parallel (replicate i is seeded with i)
bootstrap_results = Parallel(n_jobs=4)(
    delayed(boostrap_func)(i, y_true, y_score)
    for i in tqdm(range(1000), leave=False)
)
roc_curves, auc_scores = zip(*bootstrap_results)
# Report the bootstrap median with the 2.5th / 97.5th percentiles
print('Test AUC: {:.3f} ({:.3f}, {:.3f})'.format(
    np.median(auc_scores),
    np.percentile(auc_scores, 2.5),
    np.percentile(auc_scores, 97.5),
))
```
%% Output
Test AUC: 0.512 (0.364, 0.643)
%% Cell type:code id: tags:
``` python
dump(clf, 'data/model_baseline.joblib')
```
%% Output
['data/model_baseline.joblib']
%% Cell type:markdown id: tags:
## Train model with vitals features only
%% Cell type:code id: tags:
``` python
# Same row split as the combined model, restricted to columns 52 onward
# NOTE(review): presumably these are the vital-sign feature columns, which
# follow the baseline columns in the joined feature table - confirm the offset of 52.
Xtr, Xte = X[:split_idx, 52:], X[split_idx:, 52:]
ytr, yte = y[:split_idx], y[split_idx:]

# Model selection by repeated stratified k-fold CV on the training rows.
# n_splits / n_repeats are passed by keyword: newer scikit-learn releases
# make them keyword-only, so the positional form raises a TypeError there.
cv_splits, cv_repeat = 5, 20
cv = model_selection.RepeatedStratifiedKFold(
    n_splits=cv_splits, n_repeats=cv_repeat, random_state=0,
)
clf = model_selection.GridSearchCV(
    clone(base_estimator), param_grid,
    cv=cv, scoring='roc_auc', n_jobs=5,
)
clf.fit(Xtr, ytr)

# AUROC of the fitted search object on the held-out rows
test_score = metrics.roc_auc_score(yte, clf.decision_function(Xte))
```
%% Cell type:code id: tags:
``` python
y_true = yte
y_score = clf.decision_function(Xte)
def boostrap_func(i, y_true, y_score):
    """Compute the ROC curve and AUROC of one bootstrap replicate.

    Labels and scores are resampled together, with replacement, using ``i``
    as the random seed so that replicate ``i`` is reproducible.

    Returns a ``(roc_curve_output, auc)`` tuple.
    """
    labels_b, scores_b = utils.resample(y_true, y_score, replace=True, random_state=i)
    curve = metrics.roc_curve(labels_b, scores_b)
    auc = metrics.roc_auc_score(labels_b, scores_b)
    return curve, auc
# Run 1000 bootstrap replicates in parallel (replicate i is seeded with i)
bootstrap_results = Parallel(n_jobs=4)(
    delayed(boostrap_func)(i, y_true, y_score)
    for i in tqdm(range(1000), leave=False)
)
roc_curves, auc_scores = zip(*bootstrap_results)
# Report the bootstrap median with the 2.5th / 97.5th percentiles
print('Test AUC: {:.3f} ({:.3f}, {:.3f})'.format(
    np.median(auc_scores),
    np.percentile(auc_scores, 2.5),
    np.percentile(auc_scores, 97.5),
))
```
%% Output
Test AUC: 0.633 (0.507, 0.757)
%% Cell type:code id: tags:
``` python
dump(clf, 'data/model_vitals.joblib')
```
%% Output
['data/model_vitals.joblib']
%% Cell type:code id: tags:
``` python
```
This diff is collapsed.
This diff is collapsed.
# JCO CCI - aGVHD_prediction
## Overview
- This is the code repository for the manuscript "Predicting Acute Graft-versus-Host Disease Using Machine Learning and Vital Sign Data from Electronic Health Records".
- Authors: Shengpu Tang, Grant Chappell, Amanda Mazzoli, Muneesh Tewari, Sung Won Choi\*, and Jenna Wiens\*
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment