diff --git a/huggingface_mae.py b/huggingface_mae.py index 3bbefae..1f697b1 100644 --- a/huggingface_mae.py +++ b/huggingface_mae.py @@ -219,7 +219,10 @@ def compute_MAE_loss( if not self.mask_fourier_loss: floss = floss.mean() else: - floss = floss.mean(dim=-1) + if floss.dim() == 4: + floss = floss.mean(dim=(-2, -1)) + else: + floss = floss.mean(dim=-1) floss = (floss * mask).sum() / mask.sum() loss_dict[self.FOURIER_LOSS] = floss.item()