diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index c29955e7d..f7fa68974 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -70,6 +70,9 @@ def __init__(self, *args, **kwargs):
 from .support2 import Support2Dataset
 from .tcga_prad import TCGAPRADDataset
 from .splitter import (
+    assert_patient_disjoint,
+    check_patient_disjoint,
+    get_patient_ids,
     sample_balanced,
     split_by_patient,
     split_by_patient_conformal,
diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py
index 2dbc94186..355b604c0 100644
--- a/pyhealth/datasets/splitter.py
+++ b/pyhealth/datasets/splitter.py
@@ -1,5 +1,5 @@
 from itertools import chain
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -10,6 +10,166 @@
 # TODO: add more splitting methods
 
 
+def _patient_id_error(index: int, patient_id_key: str, reason: str) -> ValueError:
+    """Builds the ``ValueError`` raised when sample ``index`` has no usable ID."""
+    return ValueError(f"Sample {index} {reason} '{patient_id_key}'.")
+
+
+def get_patient_ids(
+    dataset_or_samples, patient_id_key: str = "patient_id"
+) -> Set[str]:
+    """Returns patient IDs from a split, sample collection, or ID collection.
+
+    This function inspects the actual samples in ``dataset_or_samples`` instead
+    of relying on ``patient_to_index``, which may describe the source dataset
+    rather than a derived split.
+
+    Args:
+        dataset_or_samples: A ``SampleDataset``/subset, a list or tuple of sample
+            dictionaries, a list/set/tuple of patient IDs, or a single patient
+            ID string. An object that exposes ``unique_patient_ids`` but is not
+            indexable is read through that attribute.
+        patient_id_key: Key used to read patient IDs from sample dictionaries.
+
+    Returns:
+        A set of patient IDs. Every ID is converted with ``str`` so that, for
+        example, ``1`` and ``"1"`` are treated as the same patient.
+
+    Raises:
+        ValueError: If any sample is missing ``patient_id_key`` or stores ``None``.
+    """
+
+    if isinstance(dataset_or_samples, str):
+        # A bare string is one patient ID, not an iterable of characters.
+        return {dataset_or_samples}
+
+    if isinstance(dataset_or_samples, (list, tuple, set, frozenset)):
+        # A builtin collection holding no sample dictionaries is treated as a
+        # collection of patient IDs (an empty collection yields an empty set).
+        if all(not isinstance(item, dict) for item in dataset_or_samples):
+            patient_ids = set()
+            for index, patient_id in enumerate(dataset_or_samples):
+                if patient_id is None:
+                    raise _patient_id_error(index, patient_id_key, "has None for")
+                patient_ids.add(str(patient_id))
+            return patient_ids
+
+    if hasattr(dataset_or_samples, "unique_patient_ids") and not hasattr(
+        dataset_or_samples, "__getitem__"
+    ):
+        return {str(patient_id) for patient_id in dataset_or_samples.unique_patient_ids}
+
+    patient_ids = set()
+    if hasattr(dataset_or_samples, "__len__") and hasattr(dataset_or_samples, "__getitem__"):
+        samples = (dataset_or_samples[index] for index in range(len(dataset_or_samples)))
+    else:
+        samples = iter(dataset_or_samples)
+
+    for index, sample in enumerate(samples):
+        if not isinstance(sample, dict) or patient_id_key not in sample:
+            raise _patient_id_error(index, patient_id_key, "is missing")
+        patient_id = sample[patient_id_key]
+        if patient_id is None:
+            raise _patient_id_error(index, patient_id_key, "has None for")
+        patient_ids.add(str(patient_id))
+    return patient_ids
+
+
+def check_patient_disjoint(
+    *datasets,
+    names: Optional[List[str]] = None,
+    patient_id_key: str = "patient_id",
+) -> Dict[str, Any]:
+    """Checks whether patient IDs are disjoint across two or more splits.
+
+    Args:
+        *datasets: Two or more inputs accepted by ``get_patient_ids``.
+        names: Optional unique display name for each input, in the same order.
+            Defaults to ``split_0``, ``split_1``, and so on.
+        patient_id_key: Key used to read patient IDs from sample dictionaries.
+
+    Returns:
+        A dictionary with ``is_disjoint`` (``True`` when no pair of inputs
+        shares a patient ID), ``counts`` (name to number of unique patient IDs),
+        and ``overlaps`` (``"name_a/name_b"`` to the set of shared patient IDs;
+        only pairs that overlap are listed).
+
+    Raises:
+        ValueError: If fewer than two inputs are given, if ``names`` does not
+            have one unique entry per input, or if ``get_patient_ids`` rejects
+            an input.
+    """
+
+    if len(datasets) < 2:
+        raise ValueError("At least two datasets or sample collections are required.")
+    if names is None:
+        names = [f"split_{index}" for index in range(len(datasets))]
+    if len(names) != len(datasets):
+        raise ValueError("names must have the same length as datasets.")
+    # ``counts`` and ``overlaps`` are keyed by name, so a repeated name would
+    # silently overwrite another split's entry in the report.
+    if len(set(names)) != len(names):
+        raise ValueError("names must be unique.")
+
+    patient_id_sets = [
+        get_patient_ids(dataset, patient_id_key=patient_id_key) for dataset in datasets
+    ]
+    counts = {
+        name: len(patient_ids) for name, patient_ids in zip(names, patient_id_sets)
+    }
+    overlaps: Dict[str, Set[str]] = {}
+
+    for i in range(len(patient_id_sets)):
+        for j in range(i + 1, len(patient_id_sets)):
+            overlap = patient_id_sets[i] & patient_id_sets[j]
+            if overlap:
+                overlaps[f"{names[i]}/{names[j]}"] = overlap
+
+    return {
+        "is_disjoint": len(overlaps) == 0,
+        "counts": counts,
+        "overlaps": overlaps,
+    }
+
+
+def assert_patient_disjoint(
+    *datasets,
+    names: Optional[List[str]] = None,
+    patient_id_key: str = "patient_id",
+) -> Dict[str, Any]:
+    """Asserts that patient IDs are disjoint across two or more splits.
+
+    Args:
+        *datasets: Two or more inputs accepted by ``get_patient_ids``.
+        names: Optional unique display name for each input, in the same order.
+        patient_id_key: Key used to read patient IDs from sample dictionaries.
+
+    Returns:
+        The report from ``check_patient_disjoint`` when no patient is shared.
+
+    Raises:
+        AssertionError: If any pair of inputs shares a patient ID. The message
+            lists each overlapping pair with up to five example IDs.
+        ValueError: If ``check_patient_disjoint`` rejects the arguments.
+    """
+
+    report = check_patient_disjoint(
+        *datasets, names=names, patient_id_key=patient_id_key
+    )
+    if report["is_disjoint"]:
+        return report
+
+    parts = []
+    for split_pair, overlap in report["overlaps"].items():
+        examples = ", ".join(sorted(overlap)[:5])
+        parts.append(
+            f"{split_pair}: {len(overlap)} overlapping "
+            f"{patient_id_key} values"
+            + (f" (examples: {examples})" if examples else "")
+        )
+    raise AssertionError("Patient overlap detected between " + "; ".join(parts))
+
+
 def _label_to_int(label) -> int:
     """Convert a stored label (int/np scalar/torch scalar) to Python int."""
     if torch.is_tensor(label):
diff --git a/tests/core/test_patient_disjoint.py b/tests/core/test_patient_disjoint.py
new file mode 100644
index 000000000..08bcbea81
--- /dev/null
+++ b/tests/core/test_patient_disjoint.py
@@ -0,0 +1,102 @@
+import unittest
+
+from pyhealth.datasets import (
+    assert_patient_disjoint,
+    check_patient_disjoint,
+    create_sample_dataset,
+    get_patient_ids,
+    split_by_patient,
+    split_by_patient_conformal,
+    split_by_sample,
+)
+
+
+def _make_dataset(patient_counts):
+    samples = []
+    for patient_id, count in patient_counts:
+        for index in range(count):
+            samples.append(
+                {
+                    "patient_id": patient_id,
+                    "record_id": f"{patient_id}-{index}",
+                    "label": index % 2,
+                }
+            )
+    return create_sample_dataset(
+        samples=samples,
+        input_schema={},
+        output_schema={"label": "binary"},
+        in_memory=True,
+    )
+
+
+class TestPatientDisjoint(unittest.TestCase):
+    def test_split_by_patient_passes(self):
+        dataset = _make_dataset([(f"p{i}", 2) for i in range(6)])
+        train, val, test = split_by_patient(dataset, [0.5, 0.25, 0.25], seed=0)
+
+        report = assert_patient_disjoint(
+            train, val, test, names=["train", "val", "test"]
+        )
+
+        self.assertTrue(report["is_disjoint"])
+        self.assertEqual(report["overlaps"], {})
+        self.assertEqual(sum(report["counts"].values()), 6)
+
+    def test_split_by_sample_can_fail_with_repeated_patients(self):
+        dataset = _make_dataset([("shared", 6), ("other_a", 1), ("other_b", 1)])
+        train, val, test = split_by_sample(dataset, [0.5, 0.25, 0.25], seed=0)
+
+        report = check_patient_disjoint(
+            train, val, test, names=["train", "val", "test"]
+        )
+
+        self.assertFalse(report["is_disjoint"])
+        self.assertTrue(any("/" in key for key in report["overlaps"]))
+        self.assertIn("shared", set().union(*report["overlaps"].values()))
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Patient overlap detected.*shared",
+        ):
+            assert_patient_disjoint(
+                train, val, test, names=["train", "val", "test"]
+            )
+
+    def test_split_by_patient_conformal_passes(self):
+        dataset = _make_dataset([(f"p{i}", 2) for i in range(8)])
+        train, val, cal, test = split_by_patient_conformal(
+            dataset, [0.25, 0.25, 0.25, 0.25], seed=0
+        )
+
+        report = assert_patient_disjoint(
+            train, val, cal, test, names=["train", "val", "cal", "test"]
+        )
+
+        self.assertTrue(report["is_disjoint"])
+        self.assertEqual(report["overlaps"], {})
+        self.assertEqual(sum(report["counts"].values()), 8)
+
+    def test_missing_patient_id_error_is_readable(self):
+        with self.assertRaisesRegex(ValueError, "Sample 0.*patient_id"):
+            get_patient_ids([{"record_id": "r0", "label": 0}])
+
+        with self.assertRaisesRegex(ValueError, "Sample 1.*patient_id"):
+            get_patient_ids(
+                [
+                    {"patient_id": "p0", "label": 0},
+                    {"patient_id": None, "label": 1},
+                ]
+            )
+
+    def test_patient_id_collections_are_supported(self):
+        self.assertEqual(get_patient_ids("p0"), {"p0"})
+        self.assertEqual(get_patient_ids(["p0", "p1", "p0"]), {"p0", "p1"})
+        self.assertEqual(get_patient_ids({1, 2}), {"1", "2"})
+
+    def test_duplicate_names_are_rejected(self):
+        with self.assertRaisesRegex(ValueError, "names must be unique"):
+            check_patient_disjoint(["p0"], ["p1"], names=["train", "train"])
+
+
+if __name__ == "__main__":
+    unittest.main()