Skip to content
Open
39 changes: 35 additions & 4 deletions monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,39 @@
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
einops_rearrange, _ = optional_import("einops", name="rearrange")
SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


class _RearrangeFn(nn.Module):
"""Thin wrapper around einops.rearrange function, used as a fallback when the
Rearrange layer rejects axis-size kwargs in some einops builds."""

def __init__(self, pattern: str, axes_lengths: dict[str, int]) -> None:
"""Initialize the functional-rearrange fallback module.

Args:
pattern: Einops rearrangement pattern.
axes_lengths: Named axis lengths consumed by ``einops.rearrange``.
"""
super().__init__()
self.pattern = pattern
self.axes_lengths = axes_lengths

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Rearrange an input tensor using the stored einops pattern.

Args:
x: Input tensor.

Returns:
Rearranged tensor produced by ``einops.rearrange``.
"""
out: torch.Tensor = einops_rearrange(x, self.pattern, **self.axes_lengths)
return out


class PatchEmbeddingBlock(nn.Module):
"""
A patch embedding block, based on: "Dosovitskiy et al.,
Expand Down Expand Up @@ -97,14 +126,16 @@ def __init__(
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
)
elif self.proj_type == "perceptron":
# for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
# for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
self.patch_embeddings = nn.Sequential(
Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
)
try:
rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len)
except TypeError:
rearrange_layer = _RearrangeFn(f"{from_chars} -> {to_chars}", axes_len)
self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size))
Comment on lines +129 to +138
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

einops Rearrange pattern integer literals support 0.8

💡 Result:

No, einops Rearrange patterns do not support standalone integer literals (e.g., 'b 2 h w') as axis specifications. Patterns use symbolic axis names (letters like 'b', 'h'), anonymous axes (numbers in parentheses like '(b1 b2)'), ellipsis (...), or underscores (_ for skipping). Integer values are provided via the axes_lengths keyword argument (e.g., rearrange(x, 'b (h1 h) w c -> ...', h1=2)) or parsed from tensor shapes via parse_shape. This syntax has been consistent across versions, including 0.8.0+ (released 2024-04-28), with no changes introducing integer literal support in patterns per release notes, changelogs, and docs. Examples in official documentation and tutorials confirm named axes only.

Citations:


🏁 Script executed:

# Get the file and examine _PatchRearrange class and test coverage
head -150 monai/networks/blocks/patchembedding.py | tail -100

Repository: Project-MONAI/MONAI

Length of output: 4640


🏁 Script executed:

# Check for existing tests covering the fallback path
find . -name "*.py" -path "*/test*" -type f | xargs grep -l "patchembedding" | head -5

Repository: Project-MONAI/MONAI

Length of output: 130


🏁 Script executed:

# Get the _PatchRearrange class definition (around lines 32-55)
sed -n '32,56p' monai/networks/blocks/patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 958


🏁 Script executed:

# Check the test file for coverage of _PatchRearrange and the fallback path
cat monai/tests/networks/blocks/test_patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 145


🏁 Script executed:

# Find the test file with correct path
find . -name "test_patchembedding.py" -type f 2>/dev/null

Repository: Project-MONAI/MONAI

Length of output: 109


🏁 Script executed:

# Also check for any tests that import or test _PatchRearrange
rg "_PatchRearrange" --type py

Repository: Project-MONAI/MONAI

Length of output: 275


🏁 Script executed:

# Read the test file
cat ./tests/networks/blocks/test_patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 7510


🏁 Script executed:

# Check einops version constraints and any comments about the fallback logic
head -30 monai/networks/blocks/patchembedding.py

Repository: Project-MONAI/MONAI

Length of output: 1308


🏁 Script executed:

# Verify if x.view() in _PatchRearrange could fail on non-contiguous tensors
rg "x\.view\(" monai/networks/blocks/patchembedding.py -A 2 -B 2

Repository: Project-MONAI/MONAI

Length of output: 233


Retract the simplification suggestion; the current try/except approach is necessary.

einops Rearrange does not support integer literals in patterns—numeric axis values must be passed via axes_lengths (the current axes_len dict). The suggested approach of embedding integers as literals (e.g., "b c (h 16) (w 16) (d 16)") is not feasible. The fallback with _PatchRearrange is the correct strategy for version compatibility.

However, address these remaining issues in _PatchRearrange:

  1. Missing Google-style docstrings: Add docstrings to __init__ and forward methods describing arguments, return values, and behavior per coding guidelines.
  2. Type hint specificity: Change patch_size: tuple to patch_size: tuple[int, ...].
  3. Use reshape() instead of view(): Line 47 uses x.view(*v), which fails on non-contiguous tensors; reshape() is safer.
  4. Incomplete test coverage: The test suite only exercises the Rearrange path (since einops is installed). The fallback is never deterministically validated. Add a test that directly instantiates and tests _PatchRearrange independently.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/blocks/patchembedding.py` around lines 126 - 135, The
_PatchRearrange fallback needs fixes: add Google-style docstrings to the class
methods __init__ and forward describing arguments, return values, and behavior;
change the type hint patch_size: tuple to patch_size: tuple[int, ...]; replace
any use of x.view(*v) with x.reshape(*v) to avoid errors on non-contiguous
tensors; and add a deterministic unit test that directly instantiates and
exercises _PatchRearrange (independent of einops/Rearrange) to validate its
behavior for representative spatial_dims/patch_size combinations.

Comment on lines +134 to +138
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fallback path isn't deterministically covered by tests.

_PatchRearrange only executes when Rearrange(..., **axes_len) raises TypeError, i.e. only on einops ≥ 0.8. Existing test_shape cases (tests/networks/blocks/test_patchembedding.py) exercise whichever branch the installed einops version selects, so CI never compares both paths in the same run. Suggest a targeted test that forces the fallback — e.g. monkey-patch Rearrange to raise TypeError, or instantiate _PatchRearrange directly and compare its output against the Rearrange path for a known input.

As per coding guidelines: "Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/blocks/patchembedding.py` around lines 131 - 135, Tests don't
deterministically exercise the fallback _PatchRearrange because current tests
run whichever einops version is installed; add a unit test that forces the
fallback by monkey-patching the Rearrange symbol to raise TypeError (or by
directly instantiating _PatchRearrange) and assert that outputs of the fallback
match the normal Rearrange path for a representative input shape; target names
to modify/assert are _PatchRearrange, Rearrange and the patch embedding behavior
(e.g., the patch_embeddings sequence) in the existing
tests/networks/blocks/test_patchembedding.py so both branches are covered and
compared.

self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)

Expand Down
Loading