From 37c180c842377058413f7e1ceee928976b62eb80 Mon Sep 17 00:00:00 2001 From: axif Date: Sun, 7 Jun 2026 22:43:29 +0600 Subject: [PATCH 1/2] Add `postselection_mask` to `CompiledDetectorSampler.sample` for postselected QEC simulations --- CHANGELOG.md | 1 + docs/contrib.md | 4 + docs/index.md | 46 +++ src/tsim/sampler.py | 245 +++++++++++++- test/unit/test_postselection.py | 566 ++++++++++++++++++++++++++++++++ 5 files changed, 858 insertions(+), 4 deletions(-) create mode 100644 test/unit/test_postselection.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 125d4f9d..8b160892 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Fast path in the detector sampler for components whose output is deterministically given by a single error variable. These components now skip the JAX compilation and autoregressive sampling pipeline, significantly speeding up detector sampling for surface-code circuits at low physical error rates. +- `CompiledDetectorSampler.sample` now accepts an optional `postselection_mask` argument for postselected simulations (#41). The mask has length `num_detectors`; a shot is discarded when any masked detector fires. When a discarded shot is flagged by a direct detector, the expensive JAX autoregressive loop is skipped for that draw while still returning one row per requested shot. Discarded rows retain their direct detector columns and fill all other columns with zero; callers recover surviving shots by re-applying the mask to the returned detector columns. Masks that only target non-direct detectors, or that mask no direct detectors, fall back to the standard sampling path. Fully-direct circuits continue to use the NumPy fast path. diff --git a/docs/contrib.md b/docs/contrib.md index cb4c06c9..aaf9f09a 100644 --- a/docs/contrib.md +++ b/docs/contrib.md @@ -46,3 +46,7 @@ uv run just doc ``` This will launch a local server to preview the documentation. You can also run `uv run just doc-build` to build the documentation without launching the server. + +Postselection support lives in `CompiledDetectorSampler.sample(postselection_mask=...)`. +See the **Postselected simulations** section in `docs/index.md` and unit tests in +`test/unit/test_postselection.py`. diff --git a/docs/index.md b/docs/index.md index f648655b..631216f5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -84,6 +84,52 @@ When set to `True`, a noiseless reference sample is computed and XORed with the results, so that output values represent deviations from the noiseless baseline. Note that this feature should be used carefully. If detectors or observables are not deterministic, this may lead to incorrect statistics. +## Postselected simulations + +For postselected QEC experiments, pass a boolean mask to +`CompiledDetectorSampler.sample`. The mask has length `num_detectors`; a shot is +*discarded* when any masked detector fires. + +```python +import numpy as np + +c = tsim.Circuit( + """ + X_ERROR(0.01) 0 1 + M 0 1 + DETECTOR rec[-2] + DETECTOR rec[-1] rec[-2] + OBSERVABLE_INCLUDE(0) rec[-1] + """ +) +sampler = c.compile_detector_sampler() +mask = np.array([True, False]) # postselect on detector 0 + +samples = sampler.sample( + shots=10_000, + postselection_mask=mask, + append_observables=True, +) +keep = ~np.any(samples[:, : c.num_detectors] & mask, axis=1) +survivors = samples[keep] +``` + +`sample` always returns exactly `shots` rows. Shots discarded by a **direct** +postselected detector skip the expensive JAX autoregressive loop; their direct +detector columns are still correct, and all other columns are filled with +`False`. Re-apply the mask to the detector columns (as above) to recover the +surviving shots. Detectors that live inside a JAX component cannot be evaluated +without running JAX, so those shots are always computed in full. + +This is independent of `prepend_observables`, `append_observables`, +`separate_observables`, and `bit_packed`. When combined with +`use_detector_reference_sample`, the reference XOR is applied before the +postselection discard check. On surviving rows it is applied to every detector +column; on direct-discarded partial rows it is applied only to direct detector +columns (component columns stay `False`). When combined with +`use_observable_reference_sample`, the reference XOR is applied to every row +that ran JAX; direct-discarded partial rows are left unchanged. + ## Benchmarks With GPU acceleration, Tsim can achieve sampling throughput for low-magic circuits that approaches the throughput of Stim on Clifford circuits of the same size. The figure below shows a comparison for [distillation circuits](https://arxiv.org/html/2412.15165v1) (35 and 85 qubits), [cultivation circuits](https://arxiv.org/abs/2409.17595), and rotated surface code circuits. diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index bfe4e746..80cf6cc1 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -224,6 +224,71 @@ def __init__( and not self._direct_flips.any() and np.array_equal(self._direct_f_indices, np.arange(n_direct)) ) + self._direct_global_indices = np.asarray( + prog.output_order[:n_direct], dtype=np.int32 + ) + self._direct_output_mask = np.zeros(prog.num_outputs, dtype=np.bool_) + if n_direct > 0: + self._direct_output_mask[self._direct_global_indices] = True + self._direct_detector_mask = self._direct_output_mask[ + : self._num_detectors + ].copy() + + def _compute_direct_outputs(self, f_params_np: np.ndarray) -> np.ndarray: + """Scatter direct output bits into a full (batch, num_outputs) bool array. + + Non-direct columns are zero. The zero-copy fast path applies when + direct indices are 0..n-1, there are no flips, and no reindex — + i.e. the common surface-code case. + """ + batch = f_params_np.shape[0] + num_outputs = self._program.num_outputs + n_direct = len(self._direct_f_indices) + if n_direct == 0: + return np.zeros((batch, num_outputs), dtype=np.bool_) + if self._direct_zero_copy and n_direct == num_outputs: + return f_params_np[:, :n_direct].view(np.bool_).copy() + raw = ( + f_params_np[:, :n_direct].view(np.bool_) + if self._direct_zero_copy + else (f_params_np[:, self._direct_f_indices] ^ self._direct_flips).view( + np.bool_ + ) + ) + out = np.zeros((batch, num_outputs), dtype=np.bool_) + out[:, self._direct_global_indices] = raw + return out + + def _compute_reference_sample(self) -> np.ndarray: + """Return the noiseless reference sample (all f_params = 0). + + Does not advance the channel sampler RNG. + """ + num_f = self._channel_sampler.signature_matrix.shape[1] + f_ref = np.zeros((1, num_f), dtype=np.uint8) + if not self._program.components: + return self._compute_direct_outputs(f_ref)[0] + self._key, subkey = jax.random.split(self._key) + return np.asarray( + sample_program(self._program, jnp.asarray(f_ref), subkey)[0], + dtype=np.bool_, + ) + + def _resolve_batch_size( + self, + shots: int, + batch_size: int | None, + *, + compute_reference: bool, + ) -> int: + """Choose a uniform JAX batch size for ``shots`` samples.""" + if batch_size is None: + max_batch_size = self._estimate_batch_size() + num_batches = max(1, ceil(shots / max_batch_size)) + batch_size = ceil(shots / num_batches) + if compute_reference and batch_size * ceil(shots / batch_size) == shots: + batch_size += 1 + return batch_size def _peak_bytes_per_sample(self) -> int: """Estimate peak device memory per sample from compiled program structure.""" @@ -301,8 +366,12 @@ def _sample_batches( return empty, np.zeros(self._program.num_outputs, dtype=np.bool_) return empty - if not self._program.components and not compute_reference: - return self._sample_direct(shots) + if not self._program.components: + samples = self._sample_direct(shots) + if compute_reference: + reference = self._compute_reference_sample() + return samples, reference + return samples if batch_size is None: max_batch_size = self._estimate_batch_size() @@ -342,6 +411,129 @@ def _sample_batches( return result, reference return result + def _sample_batches_with_postselection( + self, + shots: int, + batch_size: int | None, + *, + postselection_mask: np.ndarray, + compute_reference: bool = False, + xor_detector_ref: bool = False, + ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: + """Sample with postselection, skipping JAX for direct discarded shots. + + Shots discarded by a direct masked detector are filled with their + direct-column bits and ``False`` elsewhere; JAX is never called for + those shots. Survivors are buffered until a full batch of + ``batch_size`` is ready, then dispatched to ``sample_program`` in one + call. The final partial batch is padded to keep the JAX batch size + fixed (avoiding recompilation) and the padding rows are discarded. + """ + if shots < 0: + raise ValueError(f"shots must be non-negative, got {shots}") + if batch_size is not None and batch_size < 1: + raise ValueError(f"batch_size must be at least 1, got {batch_size}") + + num_outputs = self._program.num_outputs + if shots == 0: + empty = np.empty((0, num_outputs), dtype=np.bool_) + empty_discarded = np.empty(0, dtype=np.bool_) + if compute_reference: + return empty, np.zeros(num_outputs, dtype=np.bool_), empty_discarded + return empty, None, empty_discarded + + postselect_direct = postselection_mask & self._direct_detector_mask + + if not self._program.components: + samples = self._sample_direct(shots) + if compute_reference: + reference = self._compute_reference_sample() + if xor_detector_ref: + samples[:, : self._num_detectors] ^= reference[ + : self._num_detectors + ] + return samples, reference, np.zeros(shots, dtype=np.bool_) + return samples, None, np.zeros(shots, dtype=np.bool_) + + if batch_size is None: + batch_size = self._resolve_batch_size( + shots, batch_size, compute_reference=False + ) + + reference: np.ndarray | None = None + if compute_reference: + reference = self._compute_reference_sample() + + result = np.zeros((shots, num_outputs), dtype=np.bool_) + was_discarded = np.zeros(shots, dtype=np.bool_) + survivor_f_buf: list[np.ndarray] = [] + survivor_idx_buf: list[int] = [] + shot_idx = 0 + + def _dispatch(f_batch: np.ndarray, indices: list[int], n_valid: int) -> None: + self._key, subkey = jax.random.split(self._key) + jax_out = np.asarray( + sample_program(self._program, jnp.asarray(f_batch), subkey) + ) + result[indices[:n_valid]] = jax_out[:n_valid] + + def _flush(*, final: bool = False) -> None: + nonlocal survivor_f_buf, survivor_idx_buf + while len(survivor_f_buf) >= batch_size: + _dispatch( + np.stack(survivor_f_buf[:batch_size]), + survivor_idx_buf[:batch_size], + batch_size, + ) + survivor_f_buf = survivor_f_buf[batch_size:] + survivor_idx_buf = survivor_idx_buf[batch_size:] + + if final and survivor_f_buf: + n_valid = len(survivor_f_buf) + f_stack = np.stack(survivor_f_buf) + f_batch = np.empty((batch_size, f_stack.shape[1]), dtype=f_stack.dtype) + f_batch[:n_valid] = f_stack + f_batch[n_valid:] = f_stack[0] + _dispatch(f_batch, survivor_idx_buf, n_valid) + survivor_f_buf = [] + survivor_idx_buf = [] + + while shot_idx < shots: + chunk = min(batch_size, shots - shot_idx) + f_params_np = self._channel_sampler.sample(chunk) + direct_full = self._compute_direct_outputs(f_params_np) + det_cols = direct_full[:, : self._num_detectors] + if xor_detector_ref and reference is not None: + det_cols = det_cols ^ reference[: self._num_detectors] + + discarded = (det_cols & postselect_direct).any(axis=1) + + result[shot_idx : shot_idx + chunk] = direct_full + was_discarded[shot_idx : shot_idx + chunk] = discarded + + survivor_local = np.flatnonzero(~discarded) + if survivor_local.size: + survivor_f_buf.extend(f_params_np[survivor_local]) + survivor_idx_buf.extend((shot_idx + survivor_local).tolist()) + + shot_idx += chunk + _flush() + + _flush(final=True) + + if xor_detector_ref and reference is not None: + det_ref = reference[: self._num_detectors] + survivors = ~was_discarded + result[survivors, : self._num_detectors] ^= det_ref + result[was_discarded, : self._num_detectors] ^= ( + det_ref & self._direct_detector_mask + ) + + if compute_reference: + assert reference is not None + return result, reference, was_discarded + return result, None, was_discarded + def _sample_direct(self, shots: int) -> np.ndarray: """Fast path when all components are direct (pure numpy, no JAX).""" f_params = self._channel_sampler.sample(shots) @@ -509,6 +701,7 @@ def sample( bit_packed: bool = False, use_detector_reference_sample: bool = False, use_observable_reference_sample: bool = False, + postselection_mask: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray]: ... @overload @@ -523,6 +716,7 @@ def sample( bit_packed: bool = False, use_detector_reference_sample: bool = False, use_observable_reference_sample: bool = False, + postselection_mask: np.ndarray | None = None, ) -> np.ndarray: ... def sample( @@ -536,6 +730,7 @@ def sample( bit_packed: bool = False, use_detector_reference_sample: bool = False, use_observable_reference_sample: bool = False, + postselection_mask: np.ndarray | None = None, ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Return detector samples from the circuit. @@ -568,6 +763,12 @@ def sample( results represent deviations from the noiseless baseline. This should only be used when observables are deterministic. Otherwise, it can unpredictably change the results. + postselection_mask: Optional boolean array of length ``num_detectors``. + When set, shots where any masked direct detector fires skip the JAX + sampling loop. All ``shots`` rows are still returned: survivors contain + the full sample, while discarded rows retain direct detector columns + and fill component columns with ``False``. Re-apply the mask to the + returned detector columns to recover surviving shots. Returns: A numpy array or tuple of numpy arrays containing the samples. @@ -587,7 +788,44 @@ def sample( use_detector_reference_sample or use_observable_reference_sample ) - if compute_reference: + if postselection_mask is not None: + mask = np.asarray(postselection_mask, dtype=np.bool_) + if mask.shape != (self._num_detectors,): + raise ValueError( + f"postselection_mask must have shape ({self._num_detectors},), " + f"got {mask.shape}" + ) + if postselection_mask is not mask: + postselection_mask = mask + if ( + not (postselection_mask & self._direct_detector_mask).any() + or not self._program.components + ): + postselection_mask = None + + if postselection_mask is not None: + if compute_reference: + samples, reference, direct_discarded = ( + self._sample_batches_with_postselection( + shots, + batch_size, + postselection_mask=postselection_mask, + compute_reference=True, + xor_detector_ref=use_detector_reference_sample, + ) + ) + assert reference is not None + num_detectors = self._num_detectors + if use_observable_reference_sample: + obs_ref = reference[num_detectors:] + samples[~direct_discarded, num_detectors:] ^= obs_ref + else: + samples, _, _ = self._sample_batches_with_postselection( + shots, + batch_size, + postselection_mask=postselection_mask, + ) + elif compute_reference: samples, reference = self._sample_batches( shots, batch_size, compute_reference=True ) @@ -618,7 +856,6 @@ def sample( ) return _maybe_bit_pack(det_samples, bit_packed=bit_packed) - # TODO: don't compute observables if they are discarded here class CompiledStateProbs(_CompiledSamplerBase): diff --git a/test/unit/test_postselection.py b/test/unit/test_postselection.py new file mode 100644 index 00000000..00c7551c --- /dev/null +++ b/test/unit/test_postselection.py @@ -0,0 +1,566 @@ +"""Unit tests for CompiledDetectorSampler.sample postselection_mask feature.""" + +from __future__ import annotations + +from unittest.mock import patch + +import numpy as np +import pytest +import stim + +import tsim.sampler as sampler_module +from tsim.circuit import Circuit + +# ────────────────────────── shared circuits ────────────────────────────────── + +# Detector 0 is direct (single X_ERROR -> M -> DETECTOR). +# Detector 1 is a compiled component (DETECTOR rec[-1] rec[-1] is trivially 0 +# but involves a JAX component because it entangles with the ZX diagram). +MIXED_DIRECT_CIRCUIT = """ +X_ERROR(0.5) 0 +R 1 +H 1 +M 0 1 +DETECTOR rec[-2] +DETECTOR rec[-1] rec[-2] +""" + +FULLY_DIRECT_CIRCUIT = """ +X_ERROR(0.5) 0 +M 0 +DETECTOR rec[-1] +""" + +ALWAYS_DISCARD_CIRCUIT = """ +X_ERROR(1) 0 +R 1 +H 1 +M 0 1 +DETECTOR rec[-2] +DETECTOR rec[-1] rec[-2] +""" + +ALWAYS_DISCARD_OBS_CIRCUIT = """ +X_ERROR(1) 0 +R 1 +H 1 +M 0 1 +DETECTOR rec[-2] +DETECTOR rec[-1] rec[-2] +OBSERVABLE_INCLUDE(0) rec[-1] +""" + +# Circuit with detectors and an observable for output-layout tests. +DET_OBS_CIRCUIT = """ +R 0 1 2 +X 2 +M 0 1 2 +DETECTOR rec[-2] +DETECTOR rec[-3] +OBSERVABLE_INCLUDE(0) rec[-1] +""" + + +def _make(circuit_str: str, seed: int = 0): + return Circuit(circuit_str).compile_detector_sampler(seed=seed) + + +def _keep(samples: np.ndarray, mask: np.ndarray) -> np.ndarray: + """Boolean row mask of shots not discarded by postselection.""" + return ~np.any(samples & mask, axis=1) + + +# ────────────────────────── validation ─────────────────────────────────────── + + +def test_postselection_mask_wrong_length_raises(): + sampler = _make(MIXED_DIRECT_CIRCUIT) + with pytest.raises(ValueError, match="postselection_mask must have shape"): + sampler.sample(1, postselection_mask=np.array([True, False, False])) + + +def test_postselection_mask_wrong_ndim_raises(): + sampler = _make(MIXED_DIRECT_CIRCUIT) + with pytest.raises(ValueError, match="postselection_mask must have shape"): + sampler.sample(1, postselection_mask=np.zeros((2, 1), dtype=np.bool_)) + + +def test_postselection_negative_shots_raises(): + sampler = _make(MIXED_DIRECT_CIRCUIT) + with pytest.raises(ValueError, match="shots must be non-negative"): + sampler.sample(-1, postselection_mask=np.array([True, False])) + + +def test_postselection_invalid_batch_size_raises(): + sampler = _make(MIXED_DIRECT_CIRCUIT) + with pytest.raises(ValueError, match="batch_size must be at least 1"): + sampler.sample(1, batch_size=0, postselection_mask=np.array([True, False])) + + +# ────────────────────────── basic shape / identity ─────────────────────────── + + +def test_postselection_none_matches_default(): + """postselection_mask=None must be bit-identical to omitting the argument.""" + a = _make(MIXED_DIRECT_CIRCUIT, seed=5).sample(16, batch_size=4) + b = _make(MIXED_DIRECT_CIRCUIT, seed=5).sample( + 16, batch_size=4, postselection_mask=None + ) + assert np.array_equal(a, b) + + +def test_postselection_return_shape_preserved(): + """Always return exactly (shots, num_detectors).""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=0) + mask = np.array([True, False]) + assert sampler.sample(0, postselection_mask=mask).shape == (0, 2) + assert sampler.sample(1, postselection_mask=mask).shape == (1, 2) + assert sampler.sample(17, batch_size=4, postselection_mask=mask).shape == (17, 2) + + +def test_postselection_zero_shots(): + sampler = _make(MIXED_DIRECT_CIRCUIT) + mask = np.array([True, False]) + assert sampler.sample(0, postselection_mask=mask).shape == (0, 2) + + +def test_postselection_all_false_mask_matches_default(): + """All-False mask → no JAX skipped; survivors == all shots.""" + mask = np.zeros(2, dtype=np.bool_) + a = _make(MIXED_DIRECT_CIRCUIT, seed=7).sample(20, batch_size=5) + b = _make(MIXED_DIRECT_CIRCUIT, seed=7).sample( + 20, batch_size=5, postselection_mask=mask + ) + assert np.array_equal(a, b) + + +# ────────────────────────── discard / partial-row semantics ────────────────── + + +def test_postselection_discarded_rows_component_cols_false(): + """Discarded rows: direct col truthful, component cols all False.""" + sampler = _make(ALWAYS_DISCARD_CIRCUIT, seed=0) + mask = np.array([True, False]) + samples = sampler.sample(20, batch_size=4, postselection_mask=mask) + + # All shots discarded because det0 always fires. + assert np.all(samples[:, 0]) + assert np.all(~samples[:, 1]) + + +def test_postselection_discarded_and_surviving_rows(): + """With 50% noise, both discarded and surviving rows appear.""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=2) + mask = np.array([True, False]) + samples = sampler.sample(64, batch_size=8, postselection_mask=mask) + + discarded = samples[:, 0] & mask[0] + assert discarded.any(), "expected some discards with 50% noise" + assert (~discarded).any(), "expected some survivors" + + # Component col False for every discarded row. + assert np.all(~samples[discarded, 1]) + + +def test_postselection_direct_cols_always_equal_numpy(): + """Direct output columns match NumPy computation for every row (discard or not).""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=3) + mask = np.array([True, False]) + drawn: list[np.ndarray] = [] + original = sampler._channel_sampler.sample + + def capture(n: int) -> np.ndarray: + batch = original(n) + drawn.append(batch.copy()) + return batch + + with patch.object(sampler._channel_sampler, "sample", side_effect=capture): + samples = sampler.sample(8, batch_size=4, postselection_mask=mask) + + f_all = np.concatenate(drawn) + expected_direct = sampler._compute_direct_outputs(f_all) + assert np.array_equal(samples & sampler._direct_output_mask, expected_direct) + + +# ────────────────────────── JAX-skip behaviour ─────────────────────────────── + + +def test_postselection_jax_never_called_for_all_direct_discards(monkeypatch): + """When every shot is discarded by a direct detector, sample_program is never called.""" + sampler = _make(ALWAYS_DISCARD_CIRCUIT, seed=0) + mask = np.array([True, False]) + calls: list[int] = [] + + original = sampler_module.sample_program + + def spy(program, f_params, key): + calls.append(f_params.shape[0]) + return original(program, f_params, key) + + monkeypatch.setattr(sampler_module, "sample_program", spy) + sampler.sample(10, batch_size=4, postselection_mask=mask) + assert calls == [] + + +def test_postselection_jax_rows_less_than_shots(): + """Total JAX rows < shots when some shots are direct-discarded.""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=0) + mask = np.array([True, False]) + jax_rows: list[int] = [] + + original = sampler_module.sample_program + + def spy(program, f_params, key): + jax_rows.append(f_params.shape[0]) + return original(program, f_params, key) + + with patch.object(sampler_module, "sample_program", side_effect=spy): + samples = sampler.sample(32, batch_size=8, postselection_mask=mask) + + discarded = np.any(samples[:, : sampler._num_detectors] & mask, axis=1) + assert sum(jax_rows) < 32 + assert sum(jax_rows) >= int((~discarded).sum()) + + +def test_postselection_jax_batch_size_fixed(monkeypatch): + """Every JAX call uses the same batch_size (no recompilation).""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=4) + mask = np.array([True, False]) + seen: list[int] = [] + + original = sampler_module.sample_program + + def spy(program, f_params, key): + seen.append(f_params.shape[0]) + return original(program, f_params, key) + + monkeypatch.setattr(sampler_module, "sample_program", spy) + sampler.sample(10, batch_size=4, postselection_mask=mask) + + assert seen, "expected at least one JAX call for surviving shots" + assert all(b == 4 for b in seen), f"non-uniform batch sizes: {seen}" + + +def test_postselection_non_direct_mask_runs_jax_for_all(monkeypatch): + """Mask on non-direct detector only → JAX runs for every shot.""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=9) + mask = np.array([False, True]) # det1 is a component + jax_rows: list[int] = [] + + original = sampler_module.sample_program + + def spy(program, f_params, key): + jax_rows.append(f_params.shape[0]) + return original(program, f_params, key) + + monkeypatch.setattr(sampler_module, "sample_program", spy) + sampler.sample(16, batch_size=8, postselection_mask=mask) + assert sum(jax_rows) == 16 + + +def test_postselection_mixed_mask_skips_jax_only_on_direct_discard(monkeypatch): + """Mixed mask: JAX skipped only when a masked direct detector fires.""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=1) + mask = np.array([True, True]) + jax_rows: list[int] = [] + + original = sampler_module.sample_program + + def spy(program, f_params, key): + jax_rows.append(f_params.shape[0]) + return original(program, f_params, key) + + monkeypatch.setattr(sampler_module, "sample_program", spy) + samples = sampler.sample(32, batch_size=8, postselection_mask=mask) + direct_discarded = samples[:, 0] & mask[0] + assert direct_discarded.any() + assert sum(jax_rows) < 32 + assert sum(jax_rows) >= int((~direct_discarded).sum()) + + +def test_postselection_detector_reference_xor_before_discard_check(monkeypatch): + """Discard check uses XOR'd detectors; ref fire cancels raw fire on direct det.""" + circuit = """ + X 0 + X_ERROR(0.5) 0 + R 1 + H 1 + M 0 1 + DETECTOR rec[-2] + DETECTOR rec[-1] rec[-2] + """ + mask = np.array([True, False]) + jax_without: list[int] = [] + jax_with: list[int] = [] + original = sampler_module.sample_program + + def make_spy(store: list[int]): + def spy(program, f_params, key): + store.append(f_params.shape[0]) + return original(program, f_params, key) + + return spy + + s1 = _make(circuit, seed=0) + monkeypatch.setattr(sampler_module, "sample_program", make_spy(jax_without)) + s1.sample(32, batch_size=8, postselection_mask=mask) + + s2 = _make(circuit, seed=0) + monkeypatch.setattr(sampler_module, "sample_program", make_spy(jax_with)) + s2.sample( + 32, + batch_size=8, + postselection_mask=mask, + use_detector_reference_sample=True, + ) + + assert sum(jax_with) > sum(jax_without) + + +def test_postselection_observable_reference_on_jax_computed_discarded_rows(): + """Observable ref XOR applies to every row that ran JAX, including caller discards.""" + circuit = """ + X 0 + X_ERROR(0.5) 0 + R 1 + H 1 + M 0 1 + DETECTOR rec[-2] + DETECTOR rec[-1] rec[-2] + OBSERVABLE_INCLUDE(0) rec[-1] + """ + mask = np.array([True, True]) + kwargs = { + "batch_size": 8, + "use_observable_reference_sample": True, + "separate_observables": True, + } + sampler = _make(circuit, seed=2) + captured: dict[str, np.ndarray] = {} + original = sampler._sample_batches_with_postselection + + def capture(*args, **kwargs): + samples, reference, direct_discarded = original(*args, **kwargs) + assert reference is not None + captured["raw_obs"] = samples[:, sampler._num_detectors :].copy() + captured["direct_discarded"] = direct_discarded.copy() + captured["obs_ref"] = reference[sampler._num_detectors :].copy() + return samples, reference, direct_discarded + + sampler._sample_batches_with_postselection = capture + _dets, obs = sampler.sample(128, postselection_mask=mask, **kwargs) + + raw_obs = captured["raw_obs"] + direct_discarded = captured["direct_discarded"] + obs_ref = captured["obs_ref"] + expected = raw_obs.copy() + expected[~direct_discarded] ^= obs_ref + + assert np.all(obs_ref) + assert np.array_equal(obs, expected) + assert not np.any(obs[direct_discarded]) + ran_jax = ~direct_discarded + assert np.any(obs[ran_jax] != raw_obs[ran_jax]) + + +# ────────────────────────── fully-direct fast path ─────────────────────────── + + +def test_postselection_fully_direct_no_jax(monkeypatch): + """Fully-direct circuits never call sample_program.""" + sampler = _make(FULLY_DIRECT_CIRCUIT, seed=0) + mask = np.array([True]) + spy = [] + + def counting_sp(program, f_params, key): + spy.append(1) + return sampler_module.sample_program(program, f_params, key) + + monkeypatch.setattr(sampler_module, "sample_program", counting_sp) + result = sampler.sample(10, postselection_mask=mask) + assert result.shape == (10, 1) + assert spy == [] + + +def test_postselection_fully_direct_matches_default(): + """Fully-direct + all-False mask → identical to default sampling.""" + a = _make(FULLY_DIRECT_CIRCUIT, seed=11).sample(50) + b = _make(FULLY_DIRECT_CIRCUIT, seed=11).sample( + 50, postselection_mask=np.zeros(1, dtype=np.bool_) + ) + assert np.array_equal(a, b) + + +def test_postselection_fully_direct_detector_reference(): + """Fully-direct circuits fall back to the standard path, including detector ref.""" + mask = np.array([True]) + with_mask = _make(FULLY_DIRECT_CIRCUIT, seed=0).sample( + 12, postselection_mask=mask, use_detector_reference_sample=True + ) + without = _make(FULLY_DIRECT_CIRCUIT, seed=0).sample( + 12, use_detector_reference_sample=True + ) + assert np.array_equal(with_mask, without) + + +# ────────────────────────── reference sample interaction ───────────────────── + + +def test_postselection_with_detector_reference_no_crash(): + """use_detector_reference_sample combined with postselection_mask must not raise.""" + sampler = _make(DET_OBS_CIRCUIT, seed=0) + mask = np.zeros(2, dtype=np.bool_) + result = sampler.sample( + 8, postselection_mask=mask, use_detector_reference_sample=True + ) + assert result.shape == (8, 2) + + +def test_postselection_detector_reference_matches_unmasked(): + """All-false mask + detector ref must match sampling without postselection.""" + mask = np.zeros(2, dtype=np.bool_) + kwargs = {"batch_size": 4, "use_detector_reference_sample": True} + with_ref = _make(MIXED_DIRECT_CIRCUIT, seed=0).sample( + 24, postselection_mask=mask, **kwargs + ) + without = _make(MIXED_DIRECT_CIRCUIT, seed=0).sample(24, **kwargs) + assert np.array_equal(with_ref, without) + + +def test_postselection_detector_reference_survivors_and_discarded(): + """Detector ref XOR applies to both survivor and discarded rows.""" + mask = np.array([True, False]) + kwargs = {"batch_size": 8, "use_detector_reference_sample": True} + samples = _make(MIXED_DIRECT_CIRCUIT, seed=3).sample( + 64, postselection_mask=mask, **kwargs + ) + keep = _keep(samples, mask) + assert keep.any() and (~keep).any() + assert not np.any(samples[keep] & mask) + assert np.all(samples[~keep, 0]) + assert np.all(~samples[~keep, 1]) + + +def test_postselection_reference_does_not_advance_channel_rng(): + """_compute_reference_sample must not draw from channel_sampler RNG.""" + sampler = _make(MIXED_DIRECT_CIRCUIT, seed=0) + original = sampler._channel_sampler.sample + calls: list[int] = [] + + def spy(n: int) -> np.ndarray: + calls.append(n) + return original(n) + + with patch.object(sampler._channel_sampler, "sample", side_effect=spy): + sampler._compute_reference_sample() + + assert calls == [], ( + "_compute_reference_sample must not call channel_sampler.sample; " + f"got calls {calls}" + ) + + +def test_postselection_with_observable_reference(): + sampler = _make(DET_OBS_CIRCUIT, seed=3) + mask = np.zeros(2, dtype=np.bool_) + dets, obs = sampler.sample( + 8, + postselection_mask=mask, + separate_observables=True, + use_observable_reference_sample=True, + ) + assert dets.shape == (8, 2) + assert obs.shape == (8, 1) + + +def test_postselection_observable_reference_skipped_on_discarded(): + """Observable ref XOR must not fill discarded rows' uncomputed obs columns.""" + mask = np.array([True, False]) + dets, obs = _make(ALWAYS_DISCARD_OBS_CIRCUIT, seed=1).sample( + 16, + batch_size=8, + postselection_mask=mask, + separate_observables=True, + use_observable_reference_sample=True, + ) + discarded = np.any(dets & mask, axis=1) + assert np.all(discarded) + assert not np.any(obs) + + +# ────────────────────────── output-layout flags ────────────────────────────── + + +def test_postselection_output_layout_append_observables(): + sampler = _make(DET_OBS_CIRCUIT, seed=0) + mask = np.array([True, False]) + result = sampler.sample(4, postselection_mask=mask, append_observables=True) + assert result.shape == (4, 3) + + +def test_postselection_output_layout_prepend_observables(): + sampler = _make(DET_OBS_CIRCUIT, seed=0) + mask = np.array([True, False]) + result = sampler.sample(4, postselection_mask=mask, prepend_observables=True) + assert result.shape == (4, 3) + + +def test_postselection_output_layout_separate_observables(): + sampler = _make(DET_OBS_CIRCUIT, seed=0) + mask = np.array([True, False]) + dets, obs = sampler.sample(4, postselection_mask=mask, separate_observables=True) + assert dets.shape == (4, 2) + assert obs.shape == (4, 1) + + +def test_postselection_output_layout_bit_packed(): + sampler = _make(DET_OBS_CIRCUIT, seed=0) + mask = np.array([True, False]) + result = sampler.sample(4, postselection_mask=mask, bit_packed=True) + assert result.dtype == np.uint8 + assert result.shape == (4, 1) + + +# ────────────────────────── surface-code integration ───────────────────────── + + +def test_postselection_surface_code_fully_direct_unchanged(): + """Surface-code detectors are all direct; mask=0 must not change samples.""" + circ = stim.Circuit.generated( + "surface_code:rotated_memory_x", + distance=3, + rounds=2, + after_clifford_depolarization=0.01, + ) + c = Circuit.from_stim_program(circ) + assert ( + c.compile_detector_sampler()._direct_detector_mask.all() + ), "expected all detectors direct for this circuit" + + mask = np.zeros(c.num_detectors, dtype=np.bool_) + a = c.compile_detector_sampler(seed=0).sample(100, batch_size=16) + b = c.compile_detector_sampler(seed=0).sample( + 100, batch_size=16, postselection_mask=mask + ) + assert np.array_equal(a, b) + + +def test_postselection_surface_code_caller_filter(): + """After postselection, caller filters survivors via mask; they have no fired detectors.""" + circ = stim.Circuit.generated( + "surface_code:rotated_memory_x", + distance=3, + rounds=2, + after_clifford_depolarization=0.01, + ) + c = Circuit.from_stim_program(circ) + num_det = c.num_detectors + mask = np.zeros(num_det, dtype=np.bool_) + mask[0] = True + + samples = c.compile_detector_sampler(seed=0).sample( + 200, batch_size=32, postselection_mask=mask + ) + survivors = _keep(samples[:, :num_det], mask) + assert survivors.any() + assert not np.any(samples[survivors] & mask) From 2ada4f7376ce870971badf94f3568eaa4aa7522d Mon Sep 17 00:00:00 2001 From: axif Date: Sun, 7 Jun 2026 22:55:37 +0600 Subject: [PATCH 2/2] Refactor `_CompiledSamplerBase` to clarify direct detector column handling --- src/tsim/sampler.py | 6 ++++-- test/unit/test_postselection.py | 35 +++++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index 80cf6cc1..dd6c72be 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -423,7 +423,7 @@ def _sample_batches_with_postselection( """Sample with postselection, skipping JAX for direct discarded shots. Shots discarded by a direct masked detector are filled with their - direct-column bits and ``False`` elsewhere; JAX is never called for + direct detector columns and ``False`` elsewhere; JAX is never called for those shots. Survivors are buffered until a full batch of ``batch_size`` is ready, then dispatched to ``sample_program`` in one call. The final partial batch is padded to keep the JAX batch size @@ -508,7 +508,9 @@ def _flush(*, final: bool = False) -> None: discarded = (det_cols & postselect_direct).any(axis=1) - result[shot_idx : shot_idx + chunk] = direct_full + result[shot_idx : shot_idx + chunk, : self._num_detectors] = direct_full[ + :, : self._num_detectors + ] was_discarded[shot_idx : shot_idx + chunk] = discarded survivor_local = np.flatnonzero(~discarded) diff --git a/test/unit/test_postselection.py b/test/unit/test_postselection.py index 00c7551c..48d648f7 100644 --- a/test/unit/test_postselection.py +++ b/test/unit/test_postselection.py @@ -163,7 +163,7 @@ def test_postselection_discarded_and_surviving_rows(): def test_postselection_direct_cols_always_equal_numpy(): - """Direct output columns match NumPy computation for every row (discard or not).""" + """Direct detector columns match NumPy computation for every row.""" sampler = _make(MIXED_DIRECT_CIRCUIT, seed=3) mask = np.array([True, False]) drawn: list[np.ndarray] = [] @@ -179,7 +179,11 @@ def capture(n: int) -> np.ndarray: f_all = np.concatenate(drawn) expected_direct = sampler._compute_direct_outputs(f_all) - assert np.array_equal(samples & sampler._direct_output_mask, expected_direct) + direct_det = sampler._direct_detector_mask + assert np.array_equal( + samples[:, : sampler._num_detectors] & direct_det, + expected_direct[:, : sampler._num_detectors] & direct_det, + ) # ────────────────────────── JAX-skip behaviour ─────────────────────────────── @@ -564,3 +568,30 @@ def test_postselection_surface_code_caller_filter(): survivors = _keep(samples[:, :num_det], mask) assert survivors.any() assert not np.any(samples[survivors] & mask) + + +def test_postselection_discarded_rows_zero_direct_observable(): + """Discarded rows keep direct detector cols but zero direct observable cols.""" + circ = stim.Circuit.generated( + "surface_code:rotated_memory_x", + distance=3, + rounds=2, + after_clifford_depolarization=0.01, + ) + c = Circuit.from_stim_program(circ) + sampler = c.compile_detector_sampler(seed=0) + assert sampler._direct_output_mask[sampler._num_detectors :].any(), ( + "expected a direct observable for this circuit" + ) + + mask = np.zeros(c.num_detectors, dtype=np.bool_) + mask[0] = True + dets, obs = sampler.sample( + 64, + batch_size=16, + postselection_mask=mask, + separate_observables=True, + ) + discarded = np.any(dets & mask, axis=1) + assert discarded.any() and (~discarded).any() + assert not np.any(obs[discarded])