Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions photomaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import OrderedDict
import os
import PIL
import numpy as np
import numpy as np

import torch
from torchvision import transforms as T
Expand Down Expand Up @@ -57,9 +57,9 @@ def load_photomaker_adapter(
The subfolder location of a model file within a larger model repository on the Hub or locally.

trigger_word (`str`, *optional*, defaults to `"img"`):
The trigger word is used to identify the position of class word in the text prompt,
and it is recommended not to set it as a common word.
This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation.
The trigger word is used to identify the position of class word in the text prompt,
and it is recommended not to set it as a common word.
This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation.
"""

# Load the main state dict first.
Expand Down Expand Up @@ -112,7 +112,7 @@ def load_photomaker_adapter(
print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...")
id_encoder = PhotoMakerIDEncoder()
id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
self.id_encoder = id_encoder
self.id_image_processor = CLIPImageProcessor()

Expand All @@ -121,11 +121,11 @@ def load_photomaker_adapter(
self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

# Add trigger word token
if self.tokenizer is not None:
if self.tokenizer is not None:
self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)

self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)


def encode_prompt_with_trigger_word(
self,
Expand Down Expand Up @@ -182,8 +182,8 @@ def encode_prompt_with_trigger_word(
# Expand the class word token and corresponding mask
class_token = clean_input_ids[class_token_index]
clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \
clean_input_ids[class_token_index+1:]
clean_input_ids[class_token_index+1:]

# Truncation or padding
max_len = tokenizer.model_max_length
if len(clean_input_ids) > max_len:
Expand All @@ -195,10 +195,10 @@ def encode_prompt_with_trigger_word(

class_tokens_mask = [True if class_token_index <= i < class_token_index+num_id_images else False \
for i in range(len(clean_input_ids))]

clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

prompt_embeds = text_encoder(
clean_input_ids.to(device),
output_hidden_states=True,
Expand Down Expand Up @@ -255,11 +255,11 @@ def __call__(
):
r"""
Function invoked when calling the pipeline for generation.
Only the parameters introduced by PhotoMaker are discussed here.
Only the parameters introduced by PhotoMaker are discussed here.
For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Args:
input_id_images (`PipelineImageInput`, *optional*):
input_id_images (`PipelineImageInput`, *optional*):
Input ID Image to work with PhotoMaker.
class_tokens_mask (`torch.LongTensor`, *optional*):
Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
Expand Down Expand Up @@ -296,7 +296,7 @@ def __call__(
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
#
#
if prompt_embeds is not None and class_tokens_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
Expand Down Expand Up @@ -328,7 +328,7 @@ def __call__(

# 3. Encode input prompt
num_id_images = len(input_id_images)

(
prompt_embeds,
pooled_prompt_embeds,
Expand All @@ -342,7 +342,7 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
class_tokens_mask=class_tokens_mask,
)

# 4. Encode input prompt without the trigger word for delayed conditioning
# encode, remove trigger word token, then decode
tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
Expand Down Expand Up @@ -377,7 +377,7 @@ def __call__(

# 6. Get the update text embedding with the stacked ID embedding
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
Expand Down Expand Up @@ -494,4 +494,4 @@ def __call__(
if not return_dict:
return (image,)

return StableDiffusionXLPipelineOutput(images=image)
return StableDiffusionXLPipelineOutput(images=image)
12 changes: 10 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,17 @@ def predict(
num_steps: int = Input(
description="Number of sample steps", default=20, ge=1, le=100
),
width: int = Input(
description="Output image width", default=1024
),
height: int = Input(
description="Output image height", default=1024
),
style_strength_ratio: float = Input(
description="Style strength (%)", default=20, ge=15, le=50
),
num_outputs: int = Input(
description="Number of output images", default=1, ge=1, le=4
description="Number of output images", default=1, ge=1, le=10
),
guidance_scale: float = Input(
description="Guidance scale. A guidance scale of 1 corresponds to doing no classifier free guidance.", default=5, ge=1, le=10.0
Expand Down Expand Up @@ -202,9 +208,11 @@ def predict(
print(f"Start merge step: {start_merge_step}")
images = self.pipe(
prompt=prompt,
width=width,
height=height,
input_id_images=input_id_images,
negative_prompt=negative_prompt,
num_images_per_prompt=num_outputs,
num_images_per_prompt=num_outputs,
num_inference_steps=num_steps,
start_merge_step=start_merge_step,
generator=generator,
Expand Down