Skip to content
Open
4 changes: 3 additions & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ jobs:
python --version
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
env:
HF_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ -x --durations=0 -W ignore::DeprecationWarning -W ignore::UserWarning -W ignore::RuntimeWarning
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Filtering out these warnings makes CI logs much cleaner and easier to navigate for developers focusing on test results.

# add_pull_ready
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ ici_tensor_parallelism: 1
allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: ''
pretrained_model_name_or_path: 'Lightricks/LTX-Video'
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def create_key(seed=0):
def run(config):
rng = jax.random.PRNGKey(config.seed)

devices_array = max_utils.create_device_mesh(config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

prompts = config.prompt
negative_prompts = config.negative_prompt
controlnet_conditioning_scale = config.controlnet_conditioning_scale
Expand All @@ -48,13 +51,14 @@ def run(config):
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
config.controlnet_model_name_or_path, from_pt=config.controlnet_from_pt, dtype=config.activations_dtype
)
with mesh:
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
config.controlnet_model_name_or_path, from_pt=config.controlnet_from_pt, dtype=config.activations_dtype
)

pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
config.pretrained_model_name_or_path, controlnet=controlnet, revision=config.revision, dtype=config.activations_dtype
)
pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
config.pretrained_model_name_or_path, controlnet=controlnet, revision=config.revision, dtype=config.activations_dtype
)

scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
Expand All @@ -68,21 +72,23 @@ def run(config):
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
processed_image = pipe.prepare_image_inputs([image] * num_samples)
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=config.num_inference_steps,
neg_prompt_ids=negative_prompt_ids,
controlnet_conditioning_scale=controlnet_conditioning_scale,
jit=True,
).images

with mesh:
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=config.num_inference_steps,
neg_prompt_ids=negative_prompt_ids,
controlnet_conditioning_scale=controlnet_conditioning_scale,
jit=True,
).images

output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
output_images[0].save("generated_image.png")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
"""

import os
import functools
from absl import app
from typing import Sequence, Union, List
from datasets import load_dataset
import numpy as np
import jax
from flax import nnx
import jax.numpy as jnp
from jax.sharding import Mesh
from maxdiffusion import pyconfig, max_utils
Expand Down Expand Up @@ -110,8 +110,9 @@ def generate_dataset(config, pipeline):
vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

# JIT-compile the VAE encode function.
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
@nnx.jit
def p_vae_encode(video, rng, vae, vae_cache):
    """Jitted entry point that forwards its arguments to ``vae_encode``.

    The VAE module and its cache are taken as explicit arguments of the
    ``nnx.jit``-wrapped function instead of being bound ahead of time.

    Args:
      video: Video batch handed to ``vae_encode`` unchanged.
      rng: Random key handed to ``vae_encode`` unchanged.
      vae: VAE object handed to ``vae_encode`` unchanged.
      vae_cache: VAE cache object handed to ``vae_encode`` unchanged.

    Returns:
      Whatever ``vae_encode`` returns for these inputs.
    """
    encoded = vae_encode(video, rng, vae, vae_cache)
    return encoded

# Load dataset
ds = load_dataset(config.dataset_name, split="train")
Expand All @@ -126,7 +127,7 @@ def generate_dataset(config, pipeline):
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
with mesh:
latents = p_vae_encode(video=video, rng=new_rng)
latents = p_vae_encode(video=video, rng=new_rng, vae=pipeline.vae, vae_cache=pipeline.vae_cache)
encoder_hidden_states = text_encode(pipeline, text)
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
writer.write(create_example(latent, encoder_hidden_state))
Expand Down
24 changes: 20 additions & 4 deletions src/maxdiffusion/generate_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,18 @@ def tokenize(prompt, pipeline):
return inputs


def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
def get_unet_inputs(pipeline, scheduler_params, states, config, rng, mesh, batch_size):
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Good use of sharding constraints to ensure consistent data placement and avoid unnecessary communication or re-sharding during the inference loop.

prompt_ids = [config.prompt] * batch_size
prompt_ids = tokenize(prompt_ids, pipeline)
prompt_ids = jax.lax.with_sharding_constraint(prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None)))
negative_prompt_ids = [config.negative_prompt] * batch_size
negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
negative_prompt_ids = jax.lax.with_sharding_constraint(
negative_prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None))
)
guidance_scale = config.guidance_scale
guidance_rescale = config.guidance_rescale
num_inference_steps = config.num_inference_steps
Expand All @@ -133,6 +137,8 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
"text_encoder_2": states["text_encoder_2_state"].params,
}
prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params)
prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
pooled_embeds = jax.lax.with_sharding_constraint(pooled_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))

batch_size = prompt_embeds.shape[0]
add_time_ids = get_add_time_ids(
Expand All @@ -148,6 +154,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):

prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
add_text_embeds = jax.lax.with_sharding_constraint(add_text_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))

add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)

else:
Expand All @@ -166,8 +175,11 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):

latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)

if isinstance(scheduler_params, dict) and "scheduler" in scheduler_params:
scheduler_params = scheduler_params["scheduler"]

scheduler_state = pipeline.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
)

latents = latents * scheduler_state.init_noise_sigma
Expand Down Expand Up @@ -217,9 +229,11 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
def run(config):
checkpoint_loader = GenerateSDXL(config)
mesh = checkpoint_loader.mesh
with mesh:
pipeline, params = checkpoint_loader.load_checkpoint()
# NOTE: load_checkpoint() is called outside the mesh context intentionally.
# If checkpoint loading requires mesh-aware sharding, move this back inside `with mesh:`.
pipeline, params = checkpoint_loader.load_checkpoint()

with mesh:
noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)

weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
Expand Down Expand Up @@ -303,11 +317,13 @@ def run(config):
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
p_run_inference(states).block_until_ready()
print("compile time: ", (time.time() - s))

s = time.time()
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
images = p_run_inference(states).block_until_ready()
print("inference time: ", (time.time() - s))

images = jax.experimental.multihost_utils.process_allgather(images, tiled=True)
numpy_images = np.array(images)
images = VaeImageProcessor.numpy_to_pil(numpy_images)
Expand Down
9 changes: 6 additions & 3 deletions src/maxdiffusion/tests/data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

import os
import pytest
import functools
import jax
import jax.numpy as jnp
from flax import nnx
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh
from .. import pyconfig
Expand Down Expand Up @@ -81,11 +81,14 @@ def test_wan_vae_encode_normalization(self):
video = load_video(video_path)
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width)]
videos = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

@nnx.jit
def p_vae_encode(video, rng, vae, vae_cache):
    """Jitted test helper: pass all four arguments straight to ``vae_encode``.

    Args:
      video: Preprocessed video array forwarded to ``vae_encode``.
      rng: Random key forwarded to ``vae_encode``.
      vae: VAE object forwarded to ``vae_encode``.
      vae_cache: VAE cache object forwarded to ``vae_encode``.

    Returns:
      The result of ``vae_encode`` for the given inputs.
    """
    encoded = vae_encode(video, rng, vae, vae_cache)
    return encoded

rng = jax.random.key(config.seed)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
latents = p_vae_encode(videos, rng=rng)
latents = p_vae_encode(videos, rng=rng, vae=pipeline.vae, vae_cache=pipeline.vae_cache)
# 1. Verify Channel Count (Wan 2.1 requires 16)
self.assertEqual(latents.shape[1], 16, f"Expected 16 channels, got {latents.shape[1]}")

Expand Down
8 changes: 2 additions & 6 deletions src/maxdiffusion/tests/generate_ltx2_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@ def setUpClass(cls):
)
cls.config = pyconfig.config
checkpoint_loader = LTX2Checkpointer(config=cls.config)
# Load pipeline without upsampler for simplicity in smoke test
cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False)

cls.prompt = [cls.config.prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
cls.negative_prompt = [cls.config.negative_prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
cls.prompt = [cls.config.prompt]
cls.negative_prompt = [cls.config.negative_prompt]

def test_ltx2_inference(self):
"""Test that LTX2 pipeline can run inference and produce output."""
Expand Down Expand Up @@ -92,9 +91,6 @@ def test_ltx2_inference(self):
# Check that we got frames
self.assertGreater(len(videos), 0)

# LTX2 might also produce audio, check if it's there if expected
# The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio.
# We can just log if audio is present.
if audios is not None:
print(f"Audio produced with shape: {audios[0].shape}")
self.assertGreater(len(audios), 0)
Expand Down
18 changes: 16 additions & 2 deletions src/maxdiffusion/tests/generate_sdxl_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
class Generate(unittest.TestCase):
"""Smoke test."""

def tearDown(self):
    """Run the base-class teardown, then collect garbage and clear JAX caches.

    NOTE(review): presumably this keeps memory from accumulating across the
    smoke tests in this class -- confirm against the CI runner's behaviour.
    """
    # Function-local imports, as in the original; grouped here for readability.
    import gc
    import jax

    super().tearDown()
    gc.collect()
    jax.clear_caches()

@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
def test_hyper_sdxl_lora(self):
img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
Expand All @@ -53,6 +61,7 @@ def test_hyper_sdxl_lora(self):
'diffusion_scheduler_config={"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}',
'lora_config={"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}',
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
Expand Down Expand Up @@ -84,6 +93,7 @@ def test_sdxl_config(self):
"run_name=sdxl-inference-test",
"split_head_dim=False",
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
Expand Down Expand Up @@ -116,6 +126,7 @@ def test_sdxl_from_gcs(self):
"run_name=sdxl-inference-test",
"split_head_dim=False",
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
Expand All @@ -139,14 +150,16 @@ def test_controlnet_sdxl(self):
"activations_dtype=bfloat16",
"weights_dtype=bfloat16",
f"jax_cache_dir={JAX_CACHE_DIR}",
"controlnet_image=" + os.path.join(THIS_DIR, "images", "cnet_test.png"),
"jit_initializers=False",
],
unittest=True,
)
images = generate_run_sdxl_controlnet(pyconfig.config)
test_image = np.array(images[0]).astype(np.uint8)
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
assert base_image.shape == test_image.shape
assert ssim_compare >= 0.70
assert ssim_compare >= 0.80

@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
def test_sdxl_lightning(self):
Expand All @@ -158,14 +171,15 @@ def test_sdxl_lightning(self):
os.path.join(THIS_DIR, "..", "configs", "base_xl_lightning.yml"),
"run_name=sdxl-lightning-test",
f"jax_cache_dir={JAX_CACHE_DIR}",
"jit_initializers=False",
],
unittest=True,
)
images = generate_run_xl(pyconfig.config)
test_image = np.array(images[0]).astype(np.uint8)
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
assert base_image.shape == test_image.shape
assert ssim_compare >= 0.70
assert ssim_compare >= 0.80


if __name__ == "__main__":
Expand Down
Loading
Loading