From 4772f4b1643abbbdcaf12ab90736f3c207c889a4 Mon Sep 17 00:00:00 2001 From: Denis Mysenko Date: Sat, 24 Feb 2024 09:24:09 +1100 Subject: [PATCH] + expose width and height in the form --- photomaker/pipeline.py | 38 +++++++++++++++++++------------------- predict.py | 12 ++++++++++-- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/photomaker/pipeline.py b/photomaker/pipeline.py index dc2fedc..1596f7e 100755 --- a/photomaker/pipeline.py +++ b/photomaker/pipeline.py @@ -2,7 +2,7 @@ from collections import OrderedDict import os import PIL -import numpy as np +import numpy as np import torch from torchvision import transforms as T @@ -57,9 +57,9 @@ def load_photomaker_adapter( The subfolder location of a model file within a larger model repository on the Hub or locally. trigger_word (`str`, *optional*, defaults to `"img"`): - The trigger word is used to identify the position of class word in the text prompt, - and it is recommended not to set it as a common word. - This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation. + The trigger word is used to identify the position of class word in the text prompt, + and it is recommended not to set it as a common word. + This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation. """ # Load the main state dict first. @@ -112,7 +112,7 @@ def load_photomaker_adapter( print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...") id_encoder = PhotoMakerIDEncoder() id_encoder.load_state_dict(state_dict["id_encoder"], strict=True) - id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype) + id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype) self.id_encoder = id_encoder self.id_image_processor = CLIPImageProcessor() @@ -121,11 +121,11 @@ def load_photomaker_adapter( self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker") # Add trigger word token - if self.tokenizer is not None: + if self.tokenizer is not None: self.tokenizer.add_tokens([self.trigger_word], special_tokens=True) - + self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True) - + def encode_prompt_with_trigger_word( self, @@ -182,8 +182,8 @@ def encode_prompt_with_trigger_word( # Expand the class word token and corresponding mask class_token = clean_input_ids[class_token_index] clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \ - clean_input_ids[class_token_index+1:] - + clean_input_ids[class_token_index+1:] + # Truncation or padding max_len = tokenizer.model_max_length if len(clean_input_ids) > max_len: @@ -195,10 +195,10 @@ def encode_prompt_with_trigger_word( class_tokens_mask = [True if class_token_index <= i < class_token_index+num_id_images else False \ for i in range(len(clean_input_ids))] - + clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0) class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0) - + prompt_embeds = text_encoder( clean_input_ids.to(device), output_hidden_states=True, @@ -255,11 +255,11 @@ def __call__( ): r""" Function invoked when calling the pipeline for generation. - Only the parameters introduced by PhotoMaker are discussed here. + Only the parameters introduced by PhotoMaker are discussed here. For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Args: - input_id_images (`PipelineImageInput`, *optional*): + input_id_images (`PipelineImageInput`, *optional*): Input ID Image to work with PhotoMaker. class_tokens_mask (`torch.LongTensor`, *optional*): Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word. @@ -296,7 +296,7 @@ def __call__( pooled_prompt_embeds, negative_pooled_prompt_embeds, ) - # + # if prompt_embeds is not None and class_tokens_mask is None: raise ValueError( "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`." @@ -328,7 +328,7 @@ def __call__( # 3. Encode input prompt num_id_images = len(input_id_images) - + ( prompt_embeds, pooled_prompt_embeds, @@ -342,7 +342,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, class_tokens_mask=class_tokens_mask, ) - + # 4. Encode input prompt without the trigger word for delayed conditioning # encode, remove trigger word token, then decode tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False) @@ -377,7 +377,7 @@ def __call__( # 6. Get the update text embedding with the stacked ID embedding prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) - + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -494,4 +494,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) diff --git a/predict.py b/predict.py index 3cef443..17991c1 100644 --- a/predict.py +++ b/predict.py @@ -133,11 +133,17 @@ def predict( num_steps: int = Input( description="Number of sample steps", default=20, ge=1, le=100 ), + width: int = Input( + description="Output image width", default=1024 + ), + height: int = Input( + description="Output image height", default=1024 + ), style_strength_ratio: float = Input( description="Style strength (%)", default=20, ge=15, le=50 ), num_outputs: int = Input( - description="Number of output images", default=1, ge=1, le=4 + description="Number of output images", default=1, ge=1, le=10 ), guidance_scale: float = Input( description="Guidance scale. A guidance scale of 1 corresponds to doing no classifier free guidance.", default=5, ge=1, le=10.0 @@ -202,9 +208,11 @@ def predict( print(f"Start merge step: {start_merge_step}") images = self.pipe( prompt=prompt, + width=width, + height=height, input_id_images=input_id_images, negative_prompt=negative_prompt, - num_images_per_prompt=num_outputs, + num_images_per_prompt=num_outputs, num_inference_steps=num_steps, start_merge_step=start_merge_step, generator=generator,