3 changes: 3 additions & 0 deletions docs/api/metrics.rst
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@ For applicable tasks, we provide the relevant metrics for model calibration, as
Among these we also provide metrics related to uncertainty quantification, for model calibration, as well as metrics that measure the quality of prediction sets
We also provide other metrics specifically for healthcare
tasks, such as drug drug interaction (DDI) rate.
For synthetic (generative) EHR data, we provide privacy, utility, and statistical
fidelity metrics.


.. toctree::
@@ -19,3 +21,4 @@ tasks, such as drug drug interaction (DDI) rate.
    metrics/pyhealth.metrics.prediction_set
    metrics/pyhealth.metrics.fairness
    metrics/pyhealth.metrics.interpretability
    metrics/pyhealth.metrics.generative
25 changes: 25 additions & 0 deletions docs/api/metrics/pyhealth.metrics.generative.rst
@@ -0,0 +1,25 @@
pyhealth.metrics.generative
===================================

Evaluation metrics for synthetic (generative) EHR data, covering privacy,
utility, and statistical fidelity.

.. currentmodule:: pyhealth.metrics.generative

[Review comment, Collaborator] It would be good to have the tasks updated with your EHR generation task.

.. autofunction:: evaluate_synthetic_ehr

Privacy metrics
-------------------------------------

.. autofunction:: calc_nnaar

.. autofunction:: calc_membership_inference

.. autofunction:: compute_discriminator_privacy

Utility and fidelity metrics
-------------------------------------

.. autofunction:: compute_mle

.. autofunction:: compute_prevalence_metrics
5 changes: 5 additions & 0 deletions docs/api/models.rst
@@ -200,6 +200,11 @@ API Reference
    models/pyhealth.models.TFMTokenizer
    models/pyhealth.models.GAN
    models/pyhealth.models.VAE
    models/pyhealth.models.HALO
    models/pyhealth.models.GPT2
    models/pyhealth.models.PromptEHR
    models/pyhealth.models.MedGAN
    models/pyhealth.models.CorGAN
    models/pyhealth.models.SDOH
    models/pyhealth.models.VisionEmbeddingModel
    models/pyhealth.models.TextEmbedding
21 changes: 21 additions & 0 deletions docs/api/models/pyhealth.models.CorGAN.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
pyhealth.models.CorGAN
===================================

CorGAN: a Correlation-capturing Convolutional GAN for synthetic EHR generation.
A 1D-CNN (or linear) autoencoder captures local code correlations, and a WGAN
generator/critic are trained in the autoencoder's latent space. Ported from the
reference implementation
(`cor-gan <https://github.com/astorfi/cor-gan>`_) and wrapped as a PyHealth
:class:`~pyhealth.models.BaseModel`.

Reference:
Torfi, A., & Fox, E. A. (2020).
*CorGAN: Correlation-Capturing Convolutional Generative Adversarial
Networks for Generating Synthetic Healthcare Records.*
In Proceedings of the 33rd International FLAIRS Conference.
https://arxiv.org/abs/2001.09346

.. autoclass:: pyhealth.models.CorGAN
    :members:
    :undoc-members:
    :show-inheritance:
17 changes: 17 additions & 0 deletions docs/api/models/pyhealth.models.GPT2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pyhealth.models.GPT2
===================================

A decoder-only GPT-2 baseline for unconditional synthetic EHR generation,
wrapped as a PyHealth :class:`~pyhealth.models.BaseModel`. Patient visit-code
sequences are serialized into causal-LM token streams and modeled
autoregressively.
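
The serialization step can be illustrated with a minimal sketch. It assumes hypothetical special tokens (``<bos>``, ``<sep>``, ``<eos>``) and hypothetical helper names; the tokens and layout actually used by the PyHealth wrapper may differ.

```python
from typing import List, Tuple

# Hypothetical special tokens, chosen for illustration only.
BOS, VISIT_SEP, EOS = "<bos>", "<sep>", "<eos>"


def serialize_patient(visits: List[List[str]]) -> List[str]:
    """Flatten a patient's visits (each a list of codes) into one token stream."""
    tokens = [BOS]
    for i, visit in enumerate(visits):
        if i > 0:
            tokens.append(VISIT_SEP)  # mark the boundary between visits
        tokens.extend(visit)
    tokens.append(EOS)
    return tokens


def causal_lm_pairs(tokens: List[str]) -> Tuple[List[str], List[str]]:
    """Shift by one position: each input token predicts the next token."""
    return tokens[:-1], tokens[1:]


patient = [["4019", "25000"], ["41401"]]
stream = serialize_patient(patient)
inputs, targets = causal_lm_pairs(stream)
print(stream)
```

Each target is the token that follows the corresponding input, which is the standard next-token objective for a decoder-only model.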

Reference:
Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019).
*Language Models are Unsupervised Multitask Learners.* OpenAI.
https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf

.. autoclass:: pyhealth.models.GPT2
    :members:
    :undoc-members:
    :show-inheritance:
19 changes: 19 additions & 0 deletions docs/api/models/pyhealth.models.HALO.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
pyhealth.models.HALO
===================================

HALO (Hierarchical Autoregressive Language model) for synthetic EHR generation.
A faithful port of the reference implementation
(`HALO_Inpatient <https://github.com/btheodorou99/HALO_Inpatient>`_),
wrapped as a PyHealth :class:`~pyhealth.models.BaseModel`.

Reference:
Theodorou, B., Xiao, C., & Sun, J. (2023).
*Synthesize high-dimensional longitudinal electronic health records via
hierarchical autoregressive language model.*
Nature Communications, 14, 5305.
https://www.nature.com/articles/s41467-023-41093-0

.. autoclass:: pyhealth.models.HALO
    :members:
    :undoc-members:
    :show-inheritance:
21 changes: 21 additions & 0 deletions docs/api/models/pyhealth.models.MedGAN.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
pyhealth.models.MedGAN
===================================

MedGAN: a bag-of-codes Generative Adversarial Network for synthetic EHR
generation. An autoencoder is pre-trained on multi-hot patient records, then a
GAN with residual generator and minibatch-averaging discriminator is trained in
the autoencoder's latent space. Ported from the reference implementations
(`medgan <https://github.com/mp2893/medgan>`_ and its PyTorch reimplementation)
and wrapped as a PyHealth :class:`~pyhealth.models.BaseModel`.
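
As a rough illustration of the multi-hot ("bag-of-codes") input mentioned above, the sketch below collapses each patient's visits into a single binary vector over the code vocabulary. The helper names are hypothetical and this is not the preprocessing code used by the model itself.

```python
from typing import Dict, List

Patient = List[List[str]]  # a patient is a list of visits, each a list of codes


def build_vocab(patients: List[Patient]) -> Dict[str, int]:
    """Map every distinct code to a column index (sorted for determinism)."""
    codes = sorted({code for visits in patients for visit in visits for code in visit})
    return {code: i for i, code in enumerate(codes)}


def multi_hot(visits: Patient, vocab: Dict[str, int]) -> List[int]:
    """Return a 0/1 vector marking which codes ever appear for this patient."""
    vec = [0] * len(vocab)
    for visit in visits:
        for code in visit:
            vec[vocab[code]] = 1  # visit order and repeat counts are discarded
    return vec


patients = [
    [["4019", "25000"], ["4019"]],  # two visits
    [["41401"]],                    # one visit
]
vocab = build_vocab(patients)
matrix = [multi_hot(p, vocab) for p in patients]
print(matrix)
```

Because the encoding drops visit order, a model trained on it generates one code set per patient rather than a visit sequence.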

Reference:
Choi, E., Biswal, S., Malin, B., Duke, J., Stewart, W. F., & Sun, J. (2017).
*Generating Multi-label Discrete Patient Records using Generative
Adversarial Networks.*
In Proceedings of Machine Learning for Healthcare (MLHC) 2017.
https://arxiv.org/abs/1703.06490

.. autoclass:: pyhealth.models.MedGAN
    :members:
    :undoc-members:
    :show-inheritance:
20 changes: 20 additions & 0 deletions docs/api/models/pyhealth.models.PromptEHR.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pyhealth.models.PromptEHR
===================================

PromptEHR: prompt-learning BART for synthetic EHR generation. A port of the
reference implementation
(`PromptEHR <https://github.com/RyanWangZf/PromptEHR>`_) that consumes the
standard PyHealth interface and learns via a span-infilling objective with a
reparameterized soft prompt.
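
A span-infilling training pair can be sketched as follows, in the BART style where a contiguous span is replaced by a single mask token and the model learns to reconstruct it. The ``<mask>`` token and the ``mask_span`` helper are hypothetical; how PromptEHR actually selects and masks spans may differ.

```python
from typing import List, Tuple

MASK = "<mask>"  # hypothetical mask token, for illustration only


def mask_span(tokens: List[str], start: int, length: int) -> Tuple[List[str], List[str]]:
    """Replace tokens[start:start+length] with one mask token.

    Returns the corrupted sequence (encoder input) and the removed span
    (the reconstruction target).
    """
    if length < 1 or start < 0 or start + length > len(tokens):
        raise ValueError("span must lie inside the sequence")
    corrupted = tokens[:start] + [MASK] + tokens[start + length:]
    target = tokens[start:start + length]
    return corrupted, target


codes = ["4019", "25000", "41401", "2724"]
corrupted, target = mask_span(codes, start=1, length=2)
print(corrupted)
print(target)
```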

Reference:
Wang, Z., & Sun, J. (2022).
*PromptEHR: Conditional Electronic Healthcare Records Generation with
Prompt Learning.*
In Proceedings of EMNLP 2022.
https://aclanthology.org/2022.emnlp-main.185/

.. autoclass:: pyhealth.models.PromptEHR
    :members:
    :undoc-members:
    :show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
@@ -212,6 +212,7 @@ Available Tasks
    COVID-19 CXR Classification <tasks/pyhealth.tasks.COVID19CXRClassification>
    DKA Prediction (MIMIC-IV) <tasks/pyhealth.tasks.dka>
    Drug Recommendation <tasks/pyhealth.tasks.drug_recommendation>
    EHR Generation <tasks/pyhealth.tasks.generate_ehr>
    Length of Stay Prediction <tasks/pyhealth.tasks.length_of_stay_prediction>
    Medical Transcriptions Classification <tasks/pyhealth.tasks.MedicalTranscriptionsClassification>
    Mortality Prediction (Next Visit) <tasks/pyhealth.tasks.mortality_prediction>
32 changes: 32 additions & 0 deletions docs/api/tasks/pyhealth.tasks.generate_ehr.rst
@@ -0,0 +1,32 @@
pyhealth.tasks.generate_ehr
===========================================

Task that turns a longitudinal EHR dataset into per-patient, per-visit code
sequences for training unconditional synthetic-EHR generators (HALO, GPT2,
PromptEHR, MedGAN, CorGAN), plus helpers to flatten generated output into the
long-form dataframe consumed by :mod:`pyhealth.metrics.generative`.
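
The long-form layout can be sketched in plain Python. The column names (``id``, ``time``, ``visit_codes``, ``labels``) follow the schema described in ``examples/halo_mimic3.py``; the ``flatten_to_long_rows`` helper is hypothetical and is not the library's ``to_evaluation_dataframe``.

```python
from typing import Any, Dict, List


def flatten_to_long_rows(patients: List[Dict[str, Any]], label: int = 0) -> List[Dict[str, Any]]:
    """Emit one row per (patient, visit, code) event.

    Each patient dict is assumed to carry "patient_id" and "visits",
    where "visits" is a list of visits and each visit is a list of codes.
    """
    rows = []
    for patient in patients:
        pid = str(patient["patient_id"])
        for t, visit in enumerate(patient["visits"]):
            for code in visit:
                # A visit with k codes contributes k rows.
                rows.append({"id": pid, "time": t, "visit_codes": code, "labels": label})
    return rows


generated = [
    {"patient_id": "syn_0", "visits": [["4019", "25000"], ["41401"]]},
]
rows = flatten_to_long_rows(generated)
for row in rows:
    print(row)
```

A list of such rows can be passed directly to ``pandas.DataFrame`` to obtain the dataframe form.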

Task Classes
------------

.. autoclass:: pyhealth.tasks.generate_ehr.EHRGeneration
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: pyhealth.tasks.generate_ehr.EHRGenerationMIMIC3
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: pyhealth.tasks.generate_ehr.EHRGenerationMIMIC4
    :members:
    :undoc-members:
    :show-inheritance:

Helper Functions
----------------

.. autofunction:: pyhealth.tasks.generate_ehr.decode_dataset

.. autofunction:: pyhealth.tasks.generate_ehr.to_evaluation_dataframe
132 changes: 132 additions & 0 deletions examples/halo_mimic3.py
@@ -0,0 +1,132 @@
"""Example: train HALO on MIMIC-III and generate synthetic patients.

This example demonstrates:
1. Loading MIMIC-III data
2. Applying the EHRGenerationMIMIC3 task (per-visit ICD-9 code sequences)
3. Creating a SampleDataset with a NestedSequenceProcessor
4. Training the HALO generator with its custom training loop
5. Generating synthetic patients
6. Evaluating the synthetic data with the generative metrics suite
"""

import pandas as pd

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.metrics.generative import evaluate_synthetic_ehr
from pyhealth.models import HALO
from pyhealth.tasks import EHRGenerationMIMIC3

if __name__ == "__main__":
    # STEP 1: Load MIMIC-III base dataset
    base_dataset = MIMIC3Dataset(
        root="/srv/local/data/MIMIC-III/mimic-iii-clinical-database-1.4",
        tables=["diagnoses_icd"],
        dev=True,
    )

    # STEP 2: Apply the EHR generation task (unconditional, no labels).
    # This task is shared by all generators in pyhealth.models.generators.
    sample_dataset = base_dataset.set_task(EHRGenerationMIMIC3())
    print(f"Total samples: {len(sample_dataset)}")
    print(f"Input schema: {sample_dataset.input_schema}")
    print(f"Output schema: {sample_dataset.output_schema}")

    sample = sample_dataset[0]
    print("\nSample structure:")
    print(f" Patient ID: {sample['patient_id']}")
    print(f" Visits tensor shape: {tuple(sample['visits'].shape)}")

    # STEP 3: Split dataset by patient
    train_dataset, val_dataset, test_dataset = split_by_patient(
        sample_dataset, [0.8, 0.1, 0.1]
    )

    # STEP 4: Initialize HALO (small config for the dev subset)
    model = HALO(
        dataset=sample_dataset,
        embed_dim=128,
        n_heads=4,
        n_layers=4,
        n_ctx=48,
        batch_size=16,
        epochs=5,
        lr=1e-4,
        save_dir="./halo_save",
    )
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel initialized with {num_params} parameters")

    # STEP 5: Train with HALO's custom loop (saves best checkpoint to save_dir)
    model.train_model(train_dataset, val_dataset=val_dataset)

    # STEP 6: Generate synthetic patients (one per real training patient).
    synthetic = model.generate(num_samples=len(train_dataset), random_sampling=True)
    print("\nGenerated synthetic patients (first 3):")
    for patient in synthetic[:3]:
        print(f" {patient['patient_id']}: {len(patient['visits'])} visits")
        print(f" {patient['visits']}")

    # STEP 7: Evaluate the synthetic data with the generative metrics suite.
    # evaluate_synthetic_ehr (and every metric it calls) expects flat /
    # long-format dataframes -- ONE ROW PER (patient, visit, code) event --
    # with four columns:
    #   - id           patient identifier (any hashable; str here)
    #   - time         visit index / timestep (sortable; int here)
    #   - visit_codes  a SINGLE medical code (str or int; one per row,
    #                  NOT a list/array -- a visit with k codes spans
    #                  k rows)
    #   - labels       per-patient binary label (0/1, int)
    # train_df, test_df and syn_df below all share this exact schema. `labels`
    # is a placeholder here: privacy metrics ignore it and the utility metric
    # overwrites it with the next-visit prediction target.
    index_to_code = {
        v: k for k, v in sample_dataset.input_processors["visits"].code_vocab.items()
    }

    def real_subset_to_records(subset):
        for sample in subset:
            pid = str(sample["patient_id"])
            visits_tensor = sample["visits"]
            for t, visit in enumerate(visits_tensor.tolist()):
                for idx in visit:
                    code = index_to_code.get(int(idx))
                    if code in (None, "<pad>", "<unk>"):
                        continue
                    yield {"id": pid, "time": t, "visit_codes": code, "labels": 0}

    def synthetic_to_records(patients):
        for p in patients:
            pid = str(p["patient_id"])
            for t, visit in enumerate(p["visits"]):
                for code in visit:
                    yield {"id": pid, "time": t, "visit_codes": code, "labels": 0}

    schema = {"visit_codes": str, "labels": int, "time": int, "id": str}
    train_df = pd.DataFrame(real_subset_to_records(train_dataset)).astype(schema)
    test_df = pd.DataFrame(real_subset_to_records(test_dataset)).astype(schema)
    syn_df = pd.DataFrame(synthetic_to_records(synthetic)).astype(schema)
    print(
        f"\nEval rows -- train: {len(train_df)}, test: {len(test_df)}, "
        f"synthetic: {len(syn_df)}"
    )
    # Show the flat schema: one row per (patient, visit, code) event.
    print("\ntrain_df schema (one row per (patient, visit, code)):")
    print(train_df.head())

    # sample_size / n_bootstraps / n_runs are kept small for the dev subset;
    # raise them when running on the full MIMIC-III cohort.
    results = evaluate_synthetic_ehr(
        train_ehr=train_df,
        test_ehr=test_df,
        syn_ehr=syn_df,
        sample_size=min(30, len(train_dataset), len(test_dataset)),

[Review comment, Collaborator] Just wondering why dataframes and why not use List[] instead?

        mode="lstm",
        metrics="all",
        lstm_params={"embed_dim": 16, "hidden_dim": 16, "batch_size": 16, "epochs": 3},
        n_bootstraps=5,
        n_runs=3,
    )
    print("\nGenerative metrics (mean +/- std):")
    for name, (mean, std) in results.items():
        print(f" {name:30s} {mean:.4f} +/- {std:.4f}")
14 changes: 14 additions & 0 deletions pyhealth/metrics/__init__.py
@@ -1,5 +1,13 @@
from .binary import binary_metrics_fn
from .drug_recommendation import ddi_rate_score
from .generative import (
    calc_membership_inference,
    calc_nnaar,
    compute_discriminator_privacy,
    compute_mle,
    compute_prevalence_metrics,
    evaluate_synthetic_ehr,
)
from .interpretability import (
    ComprehensivenessMetric,
    Evaluator,
@@ -17,6 +25,12 @@
__all__ = [
    "binary_metrics_fn",
    "ddi_rate_score",
    "calc_nnaar",
    "calc_membership_inference",
    "compute_discriminator_privacy",
    "compute_mle",
    "compute_prevalence_metrics",
    "evaluate_synthetic_ehr",
    "ComprehensivenessMetric",
    "SufficiencyMetric",
    "RemovalBasedMetric",