Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
hf_download
hf_download/
outputs/
repo/
Expand Down
78 changes: 60 additions & 18 deletions demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
from utils.lora_utils import merge_lora_to_state_dict
from utils.fp8_optimization_utils import optimize_state_dict_with_fp8, apply_fp8_monkey_patch


parser = argparse.ArgumentParser()
Expand All @@ -38,6 +39,7 @@
def _parse_bool_arg(value):
    """Convert a command-line token such as "true", "0" or "no" into a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean
            spelling, so argparse reports a normal usage error.
    """
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser.add_argument("--port", type=int, required=False)
parser.add_argument("--inbrowser", action='store_true')
parser.add_argument("--output_dir", type=str, default='./outputs')
# Without a type, any value given on the command line (even "false" or "0") is a
# non-empty string and therefore truthy, which made offline mode impossible to
# disable. Accept a bare `--offline` as well as `--offline true|false`; the
# default stays True so existing invocations keep their behaviour.
parser.add_argument("--offline", type=_parse_bool_arg, nargs='?', const=True, default=True)
args = parser.parse_args()

# for win desktop probably use --server 127.0.0.1 --inbrowser
Expand All @@ -54,19 +56,37 @@
print(f'Free VRAM {free_mem_gb} GB')
print(f'High-VRAM Mode: {high_vram}')

text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
def _resolve_hf_snapshot(repo_cache_dir):
    """Return the snapshot directory inside a locally cached Hugging Face repo.

    The revision recorded in ``refs/main`` is preferred when it exists on disk;
    otherwise the alphabetically first snapshot directory is used, so the
    choice is deterministic instead of depending on ``os.listdir`` ordering.

    Args:
        repo_cache_dir: Path of a ``models--<org>--<name>`` cache directory.

    Raises:
        FileNotFoundError: if the repo has no downloaded snapshot.
    """
    snapshots_dir = os.path.join(repo_cache_dir, 'snapshots')

    ref_file = os.path.join(repo_cache_dir, 'refs', 'main')
    if os.path.isfile(ref_file):
        with open(ref_file, encoding='utf-8') as f:
            revision = f.read().strip()
        candidate = os.path.join(snapshots_dir, revision)
        if revision and os.path.isdir(candidate):
            return candidate

    try:
        revisions = sorted(
            entry for entry in os.listdir(snapshots_dir)
            if os.path.isdir(os.path.join(snapshots_dir, entry))
        )
    except FileNotFoundError:
        revisions = []
    if not revisions:
        raise FileNotFoundError(
            f"No cached snapshot found under {snapshots_dir!r}; "
            "download the model first or run with --offline false"
        )
    return os.path.join(snapshots_dir, revisions[0])


if args.offline:
    # HF_HOME is optional. When it is unset, fall back to the documented default
    # cache root instead of crashing with KeyError.
    # NOTE(review): HF_HUB_CACHE overrides are not honoured here — confirm whether needed.
    default_hf_home = os.path.join(
        os.environ.get('XDG_CACHE_HOME', os.path.join(os.path.expanduser('~'), '.cache')),
        'huggingface',
    )
    HF_CACHE_HUB_PATH = os.path.join(os.environ.get('HF_HOME', default_hf_home), 'hub')
    HUNYUAN_VIDEO_LOCAL_PATH = os.path.join(HF_CACHE_HUB_PATH, 'models--hunyuanvideo-community--HunyuanVideo')
    FLUX_REDUX_LOCAL_PATH = os.path.join(HF_CACHE_HUB_PATH, 'models--lllyasviel--flux_redux_bfl')
    FRAMEPACK_I2V_HY_LOCAL_PATH = os.path.join(HF_CACHE_HUB_PATH, 'models--lllyasviel--FramePackI2V_HY')

    # Resolve each snapshot once rather than re-listing the directory per component.
    hunyuan_snapshot = _resolve_hf_snapshot(HUNYUAN_VIDEO_LOCAL_PATH)
    flux_redux_snapshot = _resolve_hf_snapshot(FLUX_REDUX_LOCAL_PATH)

    text_encoder = LlamaModel.from_pretrained(os.path.join(hunyuan_snapshot, 'text_encoder'), torch_dtype=torch.float16).cpu()
    text_encoder_2 = CLIPTextModel.from_pretrained(os.path.join(hunyuan_snapshot, 'text_encoder_2'), torch_dtype=torch.float16).cpu()
    tokenizer = LlamaTokenizerFast.from_pretrained(os.path.join(hunyuan_snapshot, 'tokenizer'))
    tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(hunyuan_snapshot, 'tokenizer_2'))
    vae = AutoencoderKLHunyuanVideo.from_pretrained(os.path.join(hunyuan_snapshot, 'vae'), torch_dtype=torch.float16).cpu()

    feature_extractor = SiglipImageProcessor.from_pretrained(os.path.join(flux_redux_snapshot, 'feature_extractor'))
    image_encoder = SiglipVisionModel.from_pretrained(os.path.join(flux_redux_snapshot, 'image_encoder'), torch_dtype=torch.float16).cpu()

    # The transformer is deliberately NOT loaded here: the module-level
    # `transformer = None` below discarded it immediately, so the eager load only
    # cost start-up time and RAM. FRAMEPACK_I2V_HY_LOCAL_PATH stays defined for
    # any code that wants the local path.
else:
    text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
    text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
    tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
    tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
    vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()

    feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
    image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()

# worker() (re)loads the transformer on demand and uses the previous_* values to
# detect when the LoRA / fp8 settings changed since the last run.
transformer = None # load later
transformer_dtype = torch.bfloat16
previous_lora_file = None
previous_lora_multiplier = None
previous_fp8_optimization = None

vae.eval()
text_encoder.eval()
Expand Down Expand Up @@ -103,12 +123,13 @@


@torch.no_grad()
def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier):
global transformer, previous_lora_file, previous_lora_multiplier
def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization):
global transformer, previous_lora_file, previous_lora_multiplier, previous_fp8_optimization

model_changed = transformer is None or (
lora_file != previous_lora_file
or lora_multiplier != previous_lora_multiplier
or fp8_optimization != previous_fp8_optimization
)

total_latent_sections = (total_second_length * 24) / (latent_window_size * 4)
Expand Down Expand Up @@ -194,6 +215,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

previous_lora_file = lora_file
previous_lora_multiplier = lora_multiplier
previous_fp8_optimization = fp8_optimization

transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu()
transformer.eval()
Expand All @@ -203,13 +225,32 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
transformer.to(dtype=torch.bfloat16)
transformer.requires_grad_(False)

if lora_file is not None:
if lora_file is not None or fp8_optimization:
state_dict = transformer.state_dict()
print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
gc.collect()

# LoRA should be merged before fp8 optimization
if lora_file is not None:
# TODO It would be better to merge the LoRA into the state dict before creating the transformer instance.
# Use from_config() instead of from_pretrained to make the instance without loading.

print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
gc.collect()

if fp8_optimization:
TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8

# inplace optimization
print("Optimizing for fp8")
state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)

# apply monkey patching
apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
gc.collect()

info = transformer.load_state_dict(state_dict, strict=True, assign=True)
print(f"LoRA applied: {info}")
print(f"LoRA and/or fp8 optimization applied: {info}")

if not high_vram:
DynamicSwapInstaller.install_model(transformer, device=gpu)
Expand Down Expand Up @@ -353,15 +394,15 @@ def callback(d):
return


def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier):
def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization):
global stream
assert input_image is not None, 'No input image!'

yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

stream = AsyncStream()

async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier)
async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization)

output_filename = None

Expand Down Expand Up @@ -423,13 +464,14 @@ def end_process():
rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False) # Should not change

# This is only used when high_vram is False
gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.", visible=not high_vram)
gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=0, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.", visible=not high_vram)

mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. ")

with gr.Group():
lora_file = gr.File(label="LoRA File", file_count="single", type="filepath")
lora_multiplier = gr.Slider(label="LoRA Multiplier", minimum=0.0, maximum=1.0, value=0.8, step=0.1)
fp8_optimization = gr.Checkbox(label="FP8 Optimization", value=True)

with gr.Column():
preview_image = gr.Image(label="Next Latents", height=200, visible=False)
Expand All @@ -440,7 +482,7 @@ def end_process():

gr.HTML('<div style="text-align:center; margin-top:20px;">Share your results and find ideas at the <a href="https://x.com/search?q=framepack&f=live" target="_blank">FramePack Twitter (X) thread</a></div>')

ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier]
ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization]
start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
end_button.click(fn=end_process)

Expand Down
44 changes: 34 additions & 10 deletions demo_gradio_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
from utils.lora_utils import merge_lora_to_state_dict
from utils.fp8_optimization_utils import optimize_state_dict_with_fp8, apply_fp8_monkey_patch


parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -68,6 +69,7 @@
transformer_dtype = torch.bfloat16
previous_lora_file = None
previous_lora_multiplier = None
previous_fp8_optimization = None

vae.eval()
text_encoder.eval()
Expand Down Expand Up @@ -104,12 +106,13 @@


@torch.no_grad()
def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier):
global transformer, previous_lora_file, previous_lora_multiplier
def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization):
global transformer, previous_lora_file, previous_lora_multiplier, previous_fp8_optimization

model_changed = transformer is None or (
lora_file != previous_lora_file
or lora_multiplier != previous_lora_multiplier
or fp8_optimization != previous_fp8_optimization
)

total_latent_sections = (total_second_length * 24) / (latent_window_size * 4)
Expand Down Expand Up @@ -195,6 +198,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

previous_lora_file = lora_file
previous_lora_multiplier = lora_multiplier
previous_fp8_optimization = fp8_optimization

transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePack_F1_I2V_HY_20250503', torch_dtype=torch.bfloat16).cpu()
transformer.eval()
Expand All @@ -204,13 +208,32 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
transformer.to(dtype=torch.bfloat16)
transformer.requires_grad_(False)

if lora_file is not None:
if lora_file is not None or fp8_optimization:
state_dict = transformer.state_dict()
print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
gc.collect()

# LoRA should be merged before fp8 optimization
if lora_file is not None:
# TODO It would be better to merge the LoRA into the state dict before creating the transformer instance.
# Use from_config() instead of from_pretrained to make the instance without loading.

print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
gc.collect()

if fp8_optimization:
TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8

# inplace optimization
print("Optimizing for fp8")
state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)

# apply monkey patching
apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
gc.collect()

info = transformer.load_state_dict(state_dict, strict=True, assign=True)
print(f"LoRA applied: {info}")
print(f"LoRA and/or fp8 optimization applied: {info}")

if not high_vram:
DynamicSwapInstaller.install_model(transformer, device=gpu)
Expand Down Expand Up @@ -341,15 +364,15 @@ def callback(d):
return


def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier):
def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization):
global stream
assert input_image is not None, 'No input image!'

yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

stream = AsyncStream()

async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier)
async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization)

output_filename = None

Expand Down Expand Up @@ -418,6 +441,7 @@ def end_process():
with gr.Group():
lora_file = gr.File(label="LoRA File", file_count="single", type="filepath")
lora_multiplier = gr.Slider(label="LoRA Multiplier", minimum=0.0, maximum=1.0, value=0.8, step=0.1)
fp8_optimization = gr.Checkbox(label="FP8 Optimization", value=False)

with gr.Column():
preview_image = gr.Image(label="Next Latents", height=200, visible=False)
Expand All @@ -427,7 +451,7 @@ def end_process():

gr.HTML('<div style="text-align:center; margin-top:20px;">Share your results and find ideas at the <a href="https://x.com/search?q=framepack&f=live" target="_blank">FramePack Twitter (X) thread</a></div>')

ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier]
ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, resolution, lora_file, lora_multiplier, fp8_optimization]
start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
end_button.click(fn=end_process)

Expand Down
Loading