Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/visualization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.datasets import VisionDataset


def _hide_axis(ax: Axes, hide_ticks: bool = True, hide_ticklabels: bool = True) -> None:
def hide_axis(ax: Axes, hide_ticks: bool = True, hide_ticklabels: bool = True) -> None:
xaxis = ax.get_xaxis()
yaxis = ax.get_yaxis()
if hide_ticks:
Expand Down Expand Up @@ -41,7 +41,7 @@ def _display_sample(sample: Tensor, ax_idx: int, sample_label: str) -> None:
)
plt.imshow(sample.detach().cpu().numpy())
plt.gray()
_hide_axis(ax)
hide_axis(ax)

# For each column
for sample_idx in range(num_samples):
Expand Down
33 changes: 31 additions & 2 deletions tutorials/mnist-autoencoders.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,41 @@
"plt.imshow(sample.detach().cpu().numpy()) # plot the resulting image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### We can also observe the latent space's continuity. Pick two vectors in the latent space (`z_start` and `z_end`). We then show decoded images along the \\[`z_start`, `z_end`\\] segment. Observe how the digits gradually morph from one to the other."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"from src.visualization.utils import hide_axis\n",
"\n",
"# Pick the two endpoints by changing the vectors' value\n",
"z_start = torch.tensor([-1, -1], dtype=torch.float).cuda()\n",
"z_end = torch.tensor([-1, 1], dtype=torch.float).cuda()\n",
"\n",
"# Build the interpolated points\n",
"n_steps = 11\n",
"alphas = np.linspace(0, 1, n_steps).round(8) # Round alphas to fix numerical instabilities in `np.linspace`\n",
"interpolations = {alpha: z_start * (1 - alpha) + z_end * alpha for alpha in alphas}\n",
"\n",
"# Plot the decoded images\n",
"fig, axes = plt.subplots(1, n_steps, figsize=(2 * n_steps, 2)) # Build the grid for plotting images\n",
"for ax, (alpha, z) in zip(axes, interpolations.items()):\n",
" sample = vae_decoder(z).reshape(data_shape) # Decode the interpolated latent vectors\n",
" ax.imshow(sample.detach().cpu().numpy())\n",
" hide_axis(ax)\n",
" ax.set_xlabel(f\"alpha={alpha}\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
Expand All @@ -665,7 +694,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
Loading