diff --git a/nodes.py b/nodes.py index 9d63102..b24685f 100644 --- a/nodes.py +++ b/nodes.py @@ -289,6 +289,8 @@ def INPUT_TYPES(s): }, "optional": { "start_latent": ("LATENT", {"tooltip": "init Latents to use for image2video"} ), + "image_end_embeds": ("CLIP_VISION_OUTPUT", {"tooltip": "end image embeds to use for image2video"}), + "end_latent": ("LATENT", {"tooltip": "end Latents to use for image2video"} ), "initial_samples": ("LATENT", {"tooltip": "init Latents to use for video2video"} ), "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } @@ -299,8 +301,8 @@ def INPUT_TYPES(s): FUNCTION = "process" CATEGORY = "FramePackWrapper" - def process(self, model, shift, positive, negative, latent_window_size, use_teacache, total_second_length, teacache_rel_l1_thresh, image_embeds, steps, cfg, - guidance_scale, seed, sampler, gpu_memory_preservation, start_latent=None, initial_samples=None, denoise_strength=1.0): + def process(self, model, shift, positive, negative, latent_window_size, use_teacache, total_second_length, teacache_rel_l1_thresh, image_embeds, steps, cfg, + guidance_scale, seed, sampler, gpu_memory_preservation, start_latent=None, image_end_embeds=None, end_latent=None, initial_samples=None, denoise_strength=1.0): total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) total_latent_sections = int(max(round(total_latent_sections), 1)) print("total_latent_sections: ", total_latent_sections) @@ -323,6 +325,11 @@ def process(self, model, shift, positive, negative, latent_window_size, use_teac image_encoder_last_hidden_state = image_embeds["last_hidden_state"].to(base_dtype).to(device) + if end_latent is not None and image_end_embeds is not None: + end_latent = end_latent["samples"] * vae_scaling_factor + image_end_encoder_last_hidden_state = image_end_embeds["last_hidden_state"].to(base_dtype).to(device) + image_encoder_last_hidden_state = (image_encoder_last_hidden_state + image_end_encoder_last_hidden_state) / 2 + llama_vec = positive[0][0].to(base_dtype).to(device) clip_l_pooler = positive[0][1]["pooled_output"].to(base_dtype).to(device) @@ -373,9 +380,11 @@ def process(self, model, shift, positive, negative, latent_window_size, use_teac for latent_padding in latent_paddings: print(f"latent_padding: {latent_padding}") is_last_section = latent_padding == 0 + is_first_section = latent_padding == latent_paddings[0] + latent_padding_size = latent_padding * latent_window_size - print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}') + print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, is_first_section = {is_first_section}') indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0) clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1) @@ -385,6 +394,11 @@ def process(self, model, shift, positive, negative, latent_window_size, use_teac clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2) clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + # Use end image latent for the first section if provided + if end_latent is not None and image_end_embeds is not None and is_first_section: + clean_latents_post = end_latent.to(history_latents) + clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + #vid2vid if initial_samples is not None: