Add AnyFlow algorithm (any-step video diffusion via flow maps) #25
Enderfga wants to merge 4 commits into
AnyFlow is an any-step video diffusion method that trains a single model
u_theta(x_t, t, r) to predict the average velocity from t back to r, so
the same checkpoint supports arbitrary inference NFE.
Training has two stages, switched via config.loss_config.training_stage:
  * pretrain: flow-map prediction with a central-difference target
        target = (eps - x0) - (t - r) * dF/dt
    with dF/dt estimated by central differences at (t ± delta).
    Per-batch sampling assigns r=t to a `diffusion_ratio`
    fraction (pure flow matching) and r=0 to a
    `consistency_ratio` fraction (consistency to clean data).
  * onpolicy: distribution-matching distillation with r=0 conditioning
    on top of the pretrained flow-map weights. Inherits DMD2's
    alternating fake_score / teacher / discriminator updates.
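A minimal sketch of the two pretrain ingredients above, written for plain floats so it runs without any framework. The names `assign_r` and `central_difference_target` are hypothetical and not taken from the PR. The sketch also assumes the linear interpolation x_t = (1 - t) * x0 + t * eps, that dF/dt is a total derivative evaluated along that path, and that samples outside the two fractions draw r uniformly from [0, t]; none of these is stated in this commit message.

```python
import random


def assign_r(t_values, diffusion_ratio, consistency_ratio, rng):
    """Assign r per sample: r = t for the first fraction, r = 0 for the next.

    The remaining samples draw r uniformly from [0, t] (an assumption).
    """
    n = len(t_values)
    n_diff = int(round(diffusion_ratio * n))
    n_cons = int(round(consistency_ratio * n))
    r_values = []
    for i, t in enumerate(t_values):
        if i < n_diff:
            r_values.append(t)            # pure flow matching
        elif i < n_diff + n_cons:
            r_values.append(0.0)          # consistency to clean data
        else:
            r_values.append(rng.uniform(0.0, t))
    return r_values


def central_difference_target(u, x0, eps, t, r, delta):
    """target = (eps - x0) - (t - r) * dF/dt, with dF/dt from central differences."""

    def x_at(s):
        # Assumed linear noising path between data x0 and noise eps.
        return (1.0 - s) * x0 + s * eps

    f_plus = u(x_at(t + delta), t + delta, r)
    f_minus = u(x_at(t - delta), t - delta, r)
    dF_dt = (f_plus - f_minus) / (2.0 * delta)
    return (eps - x0) - (t - r) * dF_dt


# Toy "network" whose time derivative is known in closed form (2 * t).
toy_u = lambda x, t, r: t * t
target = central_difference_target(toy_u, x0=0.2, eps=1.0, t=0.5, r=0.1, delta=0.01)
r_values = assign_r([0.9, 0.8, 0.7, 0.6], 0.25, 0.25, random.Random(0))
```

With r = t the correction term vanishes and the target reduces to eps - x0, which is the plain flow-matching case mentioned above.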
The backbone requirement (a secondary timestep r) is already satisfied by
the Wan transformer with r_timestep=True, which MeanFlow also exercises;
no Wan-side changes are needed.
New files:
fastgen/methods/distribution_matching/anyflow.py
fastgen/methods/distribution_matching/anyflow_scheduler.py
fastgen/configs/methods/config_anyflow.py
fastgen/configs/experiments/WanT2V/config_anyflow.py
tests/test_anyflowmodel.py
Modified:
fastgen/methods/__init__.py (+1 import)
fastgen/methods/distribution_matching/README.md (+1 algorithm entry)
The multi-step rollout-with-gradient training (matching
self_forcing.py's rollout_with_gradient) is intentionally left for a
follow-up PR — the on-policy stage here uses single-step student
generation.
Signed-off-by: Enderfga <qq2639135175@gmail.com>
|
Thanks a lot for the PR! Did you test the implementation and, if yes, do you have example videos or could you share the wandb run? |
AnyFlow's released HF checkpoints store the r-pathway as
``condition_embedder.delta_embedder.*`` inside the shared
``WanTwoTimeTextImageEmbedding`` module and use ONE shared ``time_proj``
for both t and (t, r). Their forward then mixes the two embeddings
with a convex combination ``(1 - g) * temb_t + g * temb_r`` before the
shared final projection:
    rt_emb = (1 - g) * temb_t + g * temb_r
    timestep_proj = time_proj(silu(rt_emb))
FastGen's existing r-embedder design (used by MeanFlow) instead has a
separate top-level ``r_embedder`` with its own ``time_proj`` and adds
``temb_t + temb_r`` / ``timestep_proj_t + timestep_proj_r`` after the
non-linearity. The two layouts are not functionally equivalent because
``silu`` is non-linear.
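To make the non-equivalence concrete, here is a small scalar sketch of the two fusion layouts. The function names and the toy scalar `time_proj` are illustrative assumptions, not the Wan modules; the additive variant reuses the same toy projection for both branches, mirroring the deep-copied ``time_proj`` at initialisation.

```python
import math


def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def time_proj(x, weight=2.0, bias=0.1):
    # Toy scalar stand-in for the shared linear projection.
    return weight * x + bias


def fuse_gated(temb_t, temb_r, g):
    # AnyFlow layout: convex mix first, then one non-linearity and projection.
    rt_emb = (1.0 - g) * temb_t + g * temb_r
    return time_proj(silu(rt_emb))


def fuse_additive(temb_t, temb_r):
    # MeanFlow-style layout: project each branch, then add.
    return time_proj(silu(temb_t)) + time_proj(silu(temb_r))


gated = fuse_gated(1.0, -2.0, g=0.25)
additive = fuse_additive(1.0, -2.0)
```

The two results differ for generic inputs, and with g = 0 the gated path collapses to the plain t-only projection.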
Two changes:
* ``Wan.__init__``: add ``r_embedder_fusion: str = "additive"`` (default
preserves MeanFlow's behaviour) and ``r_embedder_gate_value: float =
0.25``. When ``r_embedder_fusion="gated"``, ``classify_forward_prepare``
computes the convex-mix variant and uses ``r_embedder.time_proj``
(which ``init_embedder`` already deep-copies from
``condition_embedder.time_proj``) for the shared final projection.
* ``fastgen/methods/distribution_matching/anyflow.py``: add
``remap_anyflow_keys`` helper that rewrites AnyFlow's
``condition_embedder.delta_embedder.linear_{1,2}.*`` to FastGen's
``r_embedder.time_embedder.linear_{1,2}.*`` and duplicates
``condition_embedder.time_proj.*`` into ``r_embedder.time_proj.*``
so the two projections start identical. The function is a no-op when
no AnyFlow-format keys are present.
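The key rewrite described in the second bullet can be sketched on a plain dictionary. This is a hypothetical re-implementation (`remap_keys_sketch`), not the PR's ``remap_anyflow_keys``; it assumes that the presence of any ``delta_embedder`` key is what marks a checkpoint as AnyFlow-format, and it shares values instead of cloning tensors.

```python
SRC_PREFIXES = (
    "condition_embedder.delta_embedder.linear_1.",
    "condition_embedder.delta_embedder.linear_2.",
)
SRC_ROOT = "condition_embedder.delta_embedder."
DST_ROOT = "r_embedder.time_embedder."
PROJ_SRC = "condition_embedder.time_proj."
PROJ_DST = "r_embedder.time_proj."


def remap_keys_sketch(state_dict):
    """Return a new dict with AnyFlow-style keys rewritten to the FastGen layout."""
    if not any(key.startswith(SRC_PREFIXES) for key in state_dict):
        return dict(state_dict)  # no AnyFlow-format keys: leave everything as-is

    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(SRC_PREFIXES):
            # delta_embedder.linear_{1,2}.* -> r_embedder.time_embedder.linear_{1,2}.*
            remapped[DST_ROOT + key[len(SRC_ROOT):]] = value
            continue
        remapped[key] = value
        if key.startswith(PROJ_SRC):
            # Duplicate the shared projection so both projections start identical.
            remapped[PROJ_DST + key[len(PROJ_SRC):]] = value
    return remapped


anyflow_like = {
    "condition_embedder.delta_embedder.linear_1.weight": "w1",
    "condition_embedder.delta_embedder.linear_2.bias": "b2",
    "condition_embedder.time_proj.weight": "wp",
    "blocks.0.attn.weight": "wa",
}
converted = remap_keys_sketch(anyflow_like)
```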
Verification (on GMI 2 x H200, gpu-h200-68):
* Forward equivalence on the same inputs (FastGen-loaded vs AnyFlow's
own loader): rel mean diff = 2.8% in bf16 (forward noise floor).
* Training-step loss equivalence (AnyFlow ``train_bidirection`` math
reproduced inline on both code paths, same seed): AnyFlow loss
0.381619 vs FastGen loss 0.397162, rel diff = 4.07%.
* 4-step Euler-flow inference end-to-end (text encoder + FastGen Wan +
VAE decode) produces a finite 81-frame 480x832 video matching the
AnyFlow paper's any-step inference pattern.
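For readers unfamiliar with the any-step pattern, a scalar sketch of Euler-flow sampling with a flow map follows. The helper names and the uniform time grid are assumptions for illustration (the actual scheduler may use a different schedule), and the toy `exact_u` is the closed-form average velocity for a single known data point rather than a trained network.

```python
def uniform_t_list(num_steps, max_t=1.0):
    # Hypothetical uniform grid from max_t down to 0 with num_steps intervals.
    return [max_t * (1.0 - i / num_steps) for i in range(num_steps + 1)]


def euler_flow_sample(u, x, t_list):
    """Walk from t_list[0] to t_list[-1], jumping t -> t_next with r = t_next."""
    for t, t_next in zip(t_list[:-1], t_list[1:]):
        # u(x, t, r) is the average velocity over [r, t], so one jump is exact
        # whenever u is exact, regardless of the step size.
        x = x - (t - t_next) * u(x, t, t_next)
    return x


X0_TRUE = 0.3  # toy "clean data" point


def exact_u(x, t, r):
    # On the linear path x_t = (1 - t) * x0 + t * eps, eps - x0 == (x_t - x0) / t.
    return (x - X0_TRUE) / t


noise = 1.7
samples = {nfe: euler_flow_sample(exact_u, noise, uniform_t_list(nfe)) for nfe in (1, 2, 4, 50)}
```

Because the toy average velocity is exact, every step count lands on the same point, which is the property that lets one checkpoint serve arbitrary NFE.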
Signed-off-by: Enderfga <qq2639135175@gmail.com>
|
Thanks for the review! Verification is complete on both 1.3B and 14B: inference and training-step accuracy agree to bf16 noise on the published AnyFlow checkpoints.
Inference correctness
Loaded
On identical inputs the FastGen-loaded model agrees with AnyFlow's own loader to within bf16 forward noise (rel mean diff
Training-step equivalence
Inline replica of AnyFlow's
A stub-network compare of the central-difference target tensor (so the math is isolated from network weights) gives max abs diff
Sample videos
Same prompt +
1p3b_fastgen_nfe4.mp4
14b_fastgen_nfe4.mp4
14b_fastgen_nfe50.mp4
What this PR changes
Both files are additive. Existing methods (MeanFlow, DMD2, CMs, …) keep their previous forward bit-identical. Re-pushed as commit 03ed6cd on top of the original ef13247 ("Add AnyFlow algorithm").
03ed6cd to 99c0415
Replace the single-step student forward in AnyFlow's on-policy stage with a multi-step Euler-flow rollout that enables gradients at one randomly-chosen step. This matches AnyFlow's ``WanAnyFlowPipeline.training_rollout`` (the published on-policy training mode in the reference repo) and gives the DMD generator update a usable gradient through a full denoising window instead of a single forward.
Changes:
* ``AnyFlowModel._rollout_with_gradient(batch_size, dtype, condition)``: start from pure noise at ``ns.max_t``, iterate ``student_sample_steps`` Euler-flow updates with ``r = t_next`` (mean-velocity, matching the reference default), and toggle ``torch.set_grad_enabled`` at the randomly-selected step. ``grad_step`` is broadcast from rank 0 in distributed runs so all ranks share the same gradient window. The step schedule honours ``sample_t_cfg.t_list`` when set, otherwise falls back to ``noise_scheduler.get_t_list``.
* ``_onpolicy_student_update_step`` and ``_onpolicy_fake_score_discriminator_update_step``: source ``gen_data`` from the rollout instead of a single ``self.net(input_student, ...)`` forward. ``input_student`` / ``t_student`` from ``_generate_noise_and_time`` become unused for on-policy and are discarded explicitly.
* ``_get_outputs``: when on-policy, always take the multi-step generator callable path (no longer special-cases ``student_sample_steps == 1`` for the validation hook, since the rollout output is always usable).
* ``tests/test_anyflowmodel.py``: bump ``student_sample_steps`` to 2 in the on-policy fixtures and add ``test_onpolicy_rollout_propagates_gradient``, which asserts that the rollout output keeps a usable autograd graph and that ``backward()`` reaches the student weights.
All 13 unit tests pass (`make pytest tests/test_anyflowmodel.py`).
Signed-off-by: Enderfga <qq2639135175@gmail.com>
|
Follow-up commit New
The rollout output replaces the single Unit tests bumped to 13. The new |
|
Hi @juliusberner — gentle ping. 🙏 The verification you asked for is in the follow-up comment (forward parity + training-step parity + sample videos on the published 1.3B and 14B checkpoints), and commit Happy to address any further feedback whenever you have a slot — thanks again for the early review! |
Six-stage end-to-end verification, run via single-rank torchrun-less
srun on a single H200:
(1) Build FastVideo WanTransformer3DModel with r_embedder=True,
r_embedder_fusion=gated, gate=0.25.
(2) Load nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers safetensors and
translate keys via WanVideoArchConfig.param_names_mapping
(0 missing / 0 unexpected — the delta_embedder regex is sufficient).
(3) Build AnyFlow's reference loader (FAR_Wan_Transformer3DModel).
(4) Forward parity on identical inputs — bf16 noise.
(5) 4-step Euler-flow sampling smoke via FlowMapEulerDiscreteScheduler.
(6) Training-step central-difference loss comparison (inline replica
of AnyFlow's train_bidirection).
Measured on Wan2.1-T2V-1.3B + nvidia/AnyFlow checkpoint:
forward rel mean diff : 2.55%
forward max abs diff : 7.81e-2
training loss diff : 1.33% (AnyFlow 0.381619 vs FastVideo 0.386694)
Both within bf16 kernel noise. Compare to the FastGen port at
NVlabs/FastGen#25 which reported 2.8% forward + 4.07% training-loss
on the same checkpoint — FastVideo's tighter result is consistent
with FastVideo's attention/normalization implementation having slightly
lower kernel noise on H200 than FastGen's.
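The two training-loss percentages quoted above can be reproduced from the reported loss values. The sketch assumes the relative difference is normalised by the AnyFlow reference loss; that normalisation is inferred from the numbers, not stated in the messages.

```python
def rel_diff_percent(reference, other):
    """Relative difference of `other` from `reference`, in percent."""
    return abs(other - reference) / abs(reference) * 100.0


ANYFLOW_LOSS = 0.381619  # reference loss reported in both comparisons

fastgen_pct = rel_diff_percent(ANYFLOW_LOSS, 0.397162)    # FastGen port
fastvideo_pct = rel_diff_percent(ANYFLOW_LOSS, 0.386694)  # FastVideo port

print(f"FastGen:   {fastgen_pct:.2f}%")
print(f"FastVideo: {fastvideo_pct:.2f}%")
```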
|
Hi @Enderfga,
Thanks a lot for all the evaluations and videos, this is in great shape! We'll take a closer look soon, but I wanted to ask two questions first:
|
|
@Enderfga Thanks a lot for the PR and its follow-up! |
Addresses two pieces of PR NVlabs#25 reviewer feedback:
(1) Code sharing with MeanFlow. The previous commit added AnyFlow's gated t/r mixing as an inline branch inside ``classify_forward_prepare``, which made it visually hard to tell which lines were MeanFlow's additive path and which were AnyFlow-specific. This commit factors both fusion modes into a single ``_fuse_r_embedding`` method bound on the transformer (parallel pattern to ``classify_forward_prepare`` and friends). Both paths still share the ``r_embedder.time_embedder`` / ``time_proj`` / ``act_fn`` modules; the helper just makes that sharing explicit and shrinks the call site to three lines. Forward semantics are bit-identical to the previous commit for both additive (MeanFlow) and gated (AnyFlow) modes across all three ``encoder_depth`` cases.
(2) Ship a paper-aligned on-policy stage config. Previously the only documented way to run Stage 3 was an inline tweak in the pretrain config docstring. New file ``fastgen/configs/experiments/WanT2V/config_anyflow_onpolicy.py`` inherits the pretrain config and flips the loss into "onpolicy" with the paper's Stage 3 hyperparameters (lr=2e-6, 1200 iter, GAN on at the DMD2-default 0.03, ``student_update_freq=5``). The docstring notes that the AnyFlow paper's rank-256 LoRA variant is not reproduced here because FastGen does not ship a PEFT/LoRA training path; this config is a full-rank fine-tune of a Stage 2 pretrain checkpoint.
The AnyFlow method README is updated to (a) document the new ``r_embedder_fusion="gated"`` requirement when loading the released AnyFlow HF checkpoints, (b) replace the stale "multi-step rollout deferred to a follow-up" note (already landed in ab1174d) with an explicit acknowledgement that end-to-end convergence-scale validation on the paper's training corpus is deferred to a follow-up, and (c) cross-reference both pretrain and on-policy configs.
Tests: all 13 AnyFlow + 3 MeanFlow unit tests pass.
Signed-off-by: Enderfga <qq2639135175@gmail.com>
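The "inherit the pretrain config and flip a few fields" pattern from item (2) can be sketched with a frozen dataclass. This is not FastGen's config system: the class, most field names, and every pretrain default below are placeholders invented for illustration. Only the on-policy values (lr=2e-6, 1200 iterations, GAN weight 0.03, ``student_update_freq=5``) come from the commit message.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ExperimentConfigSketch:
    # All defaults are placeholders, NOT the PR's pretrain settings.
    training_stage: str = "pretrain"
    learning_rate: float = 1e-4
    max_iterations: int = 10_000
    gan_loss_weight: float = 0.0
    student_update_freq: int = 1


pretrain_cfg = ExperimentConfigSketch()

# On-policy stage: start from the pretrain config and override only what changes.
onpolicy_cfg = replace(
    pretrain_cfg,
    training_stage="onpolicy",
    learning_rate=2e-6,
    max_iterations=1200,
    gan_loss_weight=0.03,
    student_update_freq=5,
)
```

Deriving the second config from the first keeps the shared settings in one place, so the two stages cannot silently drift apart.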
|
Thanks @juliusberner and @cxlcl, pushed commit
(1) MeanFlow code sharing. Extracted
(2) Convergence-scale validation. This PR's scope is an algorithm port, not end-to-end retraining: the AnyFlow training corpus and training tooling are not part of the public release, so standing up an independent reproduction would change the data distribution. Correctness evidence is therefore algorithmic, not convergence-based:
The README now states this scope explicitly. Convergence-scale validation on the paper's training corpus is left as a follow-up. Please advise whether that's acceptable for merge or whether you'd prefer to block on end-to-end numbers.
@cxlcl, re: config tuning.
Summary
Adds AnyFlow as a new method under fastgen/methods/distribution_matching/anyflow.py. AnyFlow trains a single model u_θ(x_t, t, r) that predicts the average velocity from t back to r, so the same checkpoint supports arbitrary inference NFE.
Training has two stages, selected via config.loss_config.training_stage:
* pretrain: flow-map prediction with a central-difference target
      target = (eps - x0) - (t - r) * dF/dt
  where dF/dt is estimated from the student's own forward at (t ± δ, r). Per-batch sampling assigns r = t to a diffusion_ratio fraction (recovering plain flow matching) and r = 0 to a consistency_ratio fraction (forcing consistency to clean data), matching the AnyFlow paper.
* onpolicy: distribution-matching distillation on top of pretrained flow-map weights. The student sample is generated via a multi-step Euler-flow rollout from pure noise (matching AnyFlow's WanAnyFlowPipeline.training_rollout), with gradients enabled at one randomly-chosen step and the rest run under torch.no_grad(). grad_step is broadcast from rank 0 so distributed runs share the same gradient window. The DMD generator update consumes the rollout output through DMD2's VSD + GAN machinery with r = 0 conditioning.
Why the Wan backbone needs minimal changes
AnyFlow requires a network that accepts a secondary timestep r. The Wan transformer already supports this via its r_embedder (enable with config.model.net.r_timestep = True), and MeanFlowModel already exercises the same code path. The only Wan-side addition is an r_embedder_fusion: str config flag (default "additive", which keeps MeanFlow / TCM / sCM forwards bit-identical) with a "gated" mode that reproduces AnyFlow's WanTwoTimeTextImageEmbedding.forward_timestep:
    rt_emb = (1 - g) * temb_t + g * temb_r
    timestep_proj = time_proj(silu(rt_emb))
using r_embedder.time_proj (already a deep-copy of condition_embedder.time_proj from Wan.init_embedder) for the shared final projection.
A remap_anyflow_keys() helper in anyflow.py rewrites the published-checkpoint keys (condition_embedder.delta_embedder.linear_{1,2}.* → r_embedder.time_embedder.linear_{1,2}.*, plus a copy of condition_embedder.time_proj.* into r_embedder.time_proj.*) so the FastGen wrapper loads NVIDIA's AnyFlow-Wan2.1-T2V-{1.3B,14B}-Diffusers releases as-is. The helper is a no-op on non-AnyFlow state dicts, so it's safe to call unconditionally.
Files
New
* fastgen/methods/distribution_matching/anyflow.py: AnyFlowModel(DMD2Model) (both stages, multi-step rollout, weight-remap helper)
* fastgen/methods/distribution_matching/anyflow_scheduler.py: lightweight FlowMapDiscreteScheduler for any-step inference (no diffusers.ConfigMixin dependency)
* fastgen/configs/methods/config_anyflow.py: method config (inherits DMD2's; adds LossConfig)
* fastgen/configs/experiments/WanT2V/config_anyflow.py: Wan2.1-T2V-1.3B reference experiment
* tests/test_anyflowmodel.py: 13 unit tests covering both stages, the rollout, and the scheduler
Modified (additive)
* fastgen/networks/Wan/network.py: adds r_embedder_fusion and r_embedder_gate_value flags (default keeps MeanFlow et al. bit-identical)
* fastgen/methods/__init__.py: +1 import line
* fastgen/methods/distribution_matching/README.md: +1 algorithm entry
Test plan
* make format && make lint clean (with pinned ruff==0.6.9)
* pytest tests/test_anyflowmodel.py: 13/13 passing
* pytest tests/test_dmd2model.py tests/test_meanflowmodel.py still 6/6 passing
* min_t/max_t
* gen_data keeps an autograd graph and backward() reaches student weights
* git commit -s)
Empirical verification on the published 1.3B and 14B checkpoints (forward equivalence, training-step loss equivalence, and any-step sample videos) is in the follow-up comment.
Out of scope
peft LoRA adapters on the student / real_score / discriminator. Doing this properly means wiring LoRA into FastGen across the full model zoo, which is a focused follow-up PR rather than something to wedge into the core algorithm port.