Add tiny transformer LLM notebook#2163

Draft
cetagostini wants to merge 2 commits into
pymc-devs:mainfrom
cetagostini:cetagostini/llm_example_branch

Contributor

@cetagostini cetagostini commented May 22, 2026

Summary

  • Add a tiny decoder-only transformer LLM gallery notebook using PyTensor and xtensor named dimensions.
  • Lower matmul-shaped xtensor dot contractions to matmul and keep outer products on the existing einsum fallback.
  • Prefer static shape constants in tensordot reshape shapes and add focused regression coverage.
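
As a rough illustration of the second bullet, the sketch below uses plain NumPy as a stand-in (it is not the PyTensor rewrite itself, and the array names and shapes are made up for the example). It shows why a contraction over one shared dimension can be served by a batched matmul, while an outer product has no shared dimension to contract and is naturally written as an einsum.

```python
import numpy as np

rng = np.random.default_rng(0)

# Matmul-shaped contraction: x has axes (batch, i, k), y has axes (batch, k, j).
x = rng.normal(size=(2, 3, 4))
y = rng.normal(size=(2, 4, 5))

via_einsum = np.einsum("bik,bkj->bij", x, y)  # generic contraction over k
via_matmul = x @ y                            # same values through batched matmul

# Outer product: the operands share no axis, so there is nothing to contract.
u = rng.normal(size=(3,))
v = rng.normal(size=(5,))
outer = np.einsum("i,j->ij", u, v)
```

Both routes give the same numbers for the matmul-shaped case; the difference is only in which kernel ends up doing the work.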

Test plan

  • `conda run -n pytensor-dev python -m ruff check pytensor/tensor/math.py pytensor/xtensor/rewriting/math.py tests/tensor/test_math.py tests/xtensor/test_math.py`
  • `conda run -n pytensor-dev python -m pytest tests/xtensor/test_math.py::test_dot tests/xtensor/test_math.py::test_dot_lowers_to_matmul tests/xtensor/test_math.py::test_dot_outer_product_falls_back_to_einsum tests/xtensor/test_math.py::test_dot_errors tests/xtensor/test_math.py::test_dot_vectorize tests/tensor/test_math.py::TestTensordot -q`

Add a new gallery notebook demonstrating a tiny decoder-only transformer LLM implemented with pytensor/xtensor (doc/gallery/transformers/tiny_transformer_llm.ipynb). Update .gitignore to exclude AI tool artifacts, gallery downloaded data, and JupyterLab session files. Also apply related updates to math implementation and rewrites (pytensor/tensor/math.py, pytensor/xtensor/rewriting/math.py) and adjust tests (tests/tensor/test_math.py, tests/xtensor/test_math.py) to match the changes.

@cetagostini cetagostini marked this pull request as draft May 22, 2026 07:50
@cetagostini cetagostini self-assigned this May 22, 2026
@cetagostini
Contributor Author

While building the notebook I found a few things that were adding overhead in xtensor. After adjusting them, xtensor now ends up even faster than plain tensor.

Comment on lines +20 to +21
constant when possible, instead of a chain of ``Mul`` nodes over individual
``ScalarConstant``s.
Member

Why? PyTensor will rewrite away the Mul of constants, so this is just eager stuff? We don't want to eagerly use static shapes in actual inputs.
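
For readers unfamiliar with what the "reshape shapes" in tensordot are, here is a minimal NumPy sketch of the reshape-then-matmul formulation. It is an illustration only, not PyTensor's implementation, and `tensordot_via_reshape` is a hypothetical helper name. Each reshape target is a product of several dimension lengths. In NumPy those are plain integers; in a symbolic graph they can be either one folded constant or a chain of multiplications that a later rewrite simplifies, which is the choice being debated here.

```python
import math

import numpy as np


def tensordot_via_reshape(a, b, n_axes):
    """Contract the last n_axes of a with the first n_axes of b."""
    split = a.ndim - n_axes
    a_keep, a_sum = a.shape[:split], a.shape[split:]
    b_sum, b_keep = b.shape[:n_axes], b.shape[n_axes:]
    assert a_sum == b_sum, "contracted axes must have matching lengths"
    # Each reshape target is a product of dimension lengths.
    a2 = a.reshape(math.prod(a_keep), math.prod(a_sum))
    b2 = b.reshape(math.prod(b_sum), math.prod(b_keep))
    # One 2D matmul, then restore the kept axes.
    return (a2 @ b2).reshape(a_keep + b_keep)


a = np.arange(24.0).reshape(2, 3, 4)
b = np.arange(60.0).reshape(3, 4, 5)
result = tensordot_via_reshape(a, b, n_axes=2)
```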

@@ -0,0 +1,1246 @@
{
Member

@ricardoV94 ricardoV94 May 22, 2026

Line #1.    from pytensor.xtensor.shape import stack as xstack

Use `import pytensor.xtensor as ptx`, and then use `ptx.stack` and the like.


@@ -0,0 +1,1246 @@
{
Member

@ricardoV94 ricardoV94 May 22, 2026

Line #32.        scores = px.dot(q, k, dim="hd") / scale          # (batch, head, time_q, time_k)

You can do `assert scores.dims == ("batch", "head", "time_q", "time_k")` to self-document the dims instead of using a comment. It also teaches that the dims are always around for introspection.
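
The idea can be sketched without PyTensor at all. The snippet below is a NumPy stand-in, not the xtensor API: `named_dot` is a hypothetical helper that carries dimension names next to the array, and the shapes are invented for the example. The point is that once dims travel with the result, an assertion documents the layout and is checked every time the code runs, which a trailing comment is not.

```python
import numpy as np


def named_dot(a, a_dims, b, b_dims, dim):
    """Hypothetical helper: contract `dim` between two name-labelled arrays."""
    out_dims = tuple(d for d in a_dims if d != dim) + tuple(
        d for d in b_dims if d != dim and d not in a_dims
    )
    # Map each distinct dimension name to one einsum letter.
    names = dict.fromkeys(a_dims + b_dims)
    letters = {d: chr(ord("a") + i) for i, d in enumerate(names)}
    spec = "{},{}->{}".format(
        "".join(letters[d] for d in a_dims),
        "".join(letters[d] for d in b_dims),
        "".join(letters[d] for d in out_dims),
    )
    return np.einsum(spec, a, b), out_dims


q_dims = ("batch", "head", "time_q", "hd")
k_dims = ("batch", "head", "time_k", "hd")
q = np.ones((2, 4, 6, 8))
k = np.ones((2, 4, 6, 8))

scores, scores_dims = named_dot(q, q_dims, k, k_dims, dim="hd")
scores = scores / np.sqrt(8)  # scale by sqrt of the head dimension

# The assertion replaces the "(batch, head, time_q, time_k)" comment.
assert scores_dims == ("batch", "head", "time_q", "time_k")
```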


@@ -0,0 +1,1246 @@
{
Member

@ricardoV94 ricardoV94 May 22, 2026

Line #4.    def gen_step(context, rng):

You can still work with xtensor variables: convert to tensor before going into the scan, convert to xtensor inside the scan, convert to tensor before returning from the scan, and convert the scan outputs to xtensor outside as soon as you get them. Basically, handle the boundary.

Also, you could make a while scan that runs until the termination token is emitted.
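
The control flow of a "run until the termination token" loop can be sketched in plain Python. This is not a PyTensor scan: the `gen_step` body here is a hypothetical stand-in that samples a random token id instead of calling the model, and `EOS`, `generate`, and `max_steps` are names invented for the example. A while-style scan would express the same two ingredients: an upper bound on the number of steps and a stop condition evaluated after each step.

```python
import numpy as np

EOS = 0  # hypothetical id of the termination token


def gen_step(context, rng):
    """Stand-in for the model: sample the next token id at random."""
    vocab_size = 5
    return int(rng.integers(vocab_size))


def generate(prompt, rng, max_steps=50):
    tokens = list(prompt)
    for _ in range(max_steps):  # upper bound on the number of steps
        nxt = gen_step(tokens, rng)
        tokens.append(nxt)
        if nxt == EOS:  # stop as soon as the termination token appears
            break
    return tokens


out = generate([3, 1], np.random.default_rng(0))
```

The output length then varies from run to run, so any downstream code has to tolerate sequences shorter than `max_steps`.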


@ricardoV94
Member

This is nice. I don't want the random xtensor changes, though: we need to investigate why it was not simplifying in your case. It may be another symptom of #2056 or something else, but it shouldn't be done in a docs PR.

@cetagostini
Contributor Author

@ricardoV94 I followed some of your comments and came up with this: #2164

@@ -0,0 +1,1246 @@
{
Member

@twiecki twiecki May 26, 2026

That's cool!

