Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/diffusers/cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from loguru import logger
from diffusers.runtime_state import get_runtime_state

class CacheEntry:
    """Container for the tensors cached on behalf of one layer.

    Attributes:
        cache_type: Free-form label describing the caching strategy.
        tensors: Flat list of cached tensors; slots that have not been
            filled yet hold ``None``.
    """

    def __init__(
        self,
        cache_type: "str",
        num_cache_tensors: int = 1,
        tensors: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ):
        """Create a cache entry.

        Args:
            cache_type: Label stored verbatim on the entry.
            num_cache_tensors: Number of empty (``None``) slots to allocate
                when ``tensors`` is not given. Ignored otherwise.
            tensors: Optional initial content, either a single tensor or a
                list/tuple of tensors.

        Raises:
            TypeError: If ``tensors`` is neither ``None``, a tensor, nor a
                list/tuple.
        """
        self.cache_type: str = cache_type
        if tensors is None:
            self.tensors: List[torch.Tensor] = [None] * num_cache_tensors
        elif isinstance(tensors, torch.Tensor):
            self.tensors = [tensors]
        elif isinstance(tensors, (list, tuple)):
            # Copy into a flat list: wrapping the sequence in another list
            # would make ``tensors[0]`` a list instead of a tensor.
            self.tensors = list(tensors)
        else:
            # Fail loudly instead of leaving ``self.tensors`` undefined.
            raise TypeError(
                "tensors must be None, a torch.Tensor or a list of "
                f"torch.Tensor, got {type(tensors).__name__}"
            )

class CacheManager:
    """Registry of per-layer cache entries keyed by ``(layer_type, layer)``."""

    def __init__(self):
        self.cache: Dict[Tuple[str, Any], CacheEntry] = {}

    def register_cache_entry(self, layer, layer_type: str, cache_type: str = "naive_cache"):
        """Create (or replace) the cache entry for ``layer``.

        Args:
            layer: Object identifying the layer; used as part of the dict
                key, so it must be hashable.
            layer_type: Category string forming the other half of the key.
            cache_type: Label forwarded to ``CacheEntry``.
        """
        self.cache[layer_type, layer] = CacheEntry(cache_type)

    def cache_update(
        self,
        new_kv: Union[torch.Tensor, List[torch.Tensor]],
        layer,
        slice_dim: int = 1,
        layer_type: str = "attn",
    ):
        """Store ``new_kv`` in the layer's cache and return the cached value.

        When the runtime state reports a single pipeline patch, or patch mode
        is off, ``new_kv`` replaces the cached tensor entirely. Otherwise
        ``new_kv`` is written into the existing cached tensor along
        ``slice_dim`` at the token range of the current pipeline patch.

        Args:
            new_kv: A tensor, or a list/tuple of tensors that is concatenated
                along the last dimension before caching.
            layer: Layer whose entry is updated; must have been registered.
            slice_dim: Dimension along which the patch is written.
            layer_type: Category string used when the layer was registered.

        Returns:
            The cached tensor. If ``new_kv`` was a list/tuple, the cached
            tensor split into two chunks along the last dimension.

        Raises:
            KeyError: If no entry is registered for ``(layer_type, layer)``.
            RuntimeError: If a patch update is requested before any full
                tensor has been cached.
        """
        return_list = isinstance(new_kv, (list, tuple))
        if return_list:
            new_kv = torch.cat(new_kv, dim=-1)

        state = get_runtime_state()
        entry = self.cache[layer_type, layer]
        if state.num_pipeline_patch == 1 or not state.patch_mode:
            kv_cache = new_kv
        else:
            patch_idx = state.pipeline_patch_idx
            token_bounds = state.pp_patches_token_start_idx_local
            start_token_idx = token_bounds[patch_idx]
            end_token_idx = token_bounds[patch_idx + 1]
            kv_cache = entry.tensors[0]
            if kv_cache is None:
                # A patch can only be written into an existing full tensor.
                raise RuntimeError(
                    "cache_update was called in patch mode before a full "
                    "tensor was cached for this layer"
                )
            kv_cache = self._update_kv_in_dim(
                kv_cache=kv_cache,
                new_kv=new_kv,
                dim=slice_dim,
                start_idx=start_token_idx,
                end_idx=end_token_idx,
            )
        entry.tensors[0] = kv_cache

        if return_list:
            # The list input is split back into two equally sized halves.
            return torch.chunk(kv_cache, 2, dim=-1)
        return kv_cache

    def _update_kv_in_dim(
        self,
        kv_cache: torch.Tensor,
        new_kv: torch.Tensor,
        dim: int,
        start_idx: int,
        end_idx: int,
    ):
        """Write ``new_kv`` into ``kv_cache[start_idx:end_idx]`` along ``dim``.

        The write happens in place and works for tensors of any rank.

        Args:
            kv_cache: Destination tensor, modified in place.
            new_kv: Values to write; must be assignable to the target slice.
            dim: Dimension to slice; negative values count from the end.
            start_idx: First index of the slice (inclusive).
            end_idx: Last index of the slice (exclusive).

        Returns:
            ``kv_cache`` itself.

        Raises:
            IndexError: If ``dim`` is out of range for ``kv_cache``.
        """
        ndim = kv_cache.dim()
        if dim < 0:
            dim += ndim
        if not 0 <= dim < ndim:
            # Refuse to skip the update silently for an unsupported dim.
            raise IndexError(
                f"dim {dim} is out of range for a tensor with {ndim} dimensions"
            )

        # Full slices for the leading dims, then the target range; trailing
        # dims are implicitly taken whole.
        index = (slice(None),) * dim + (slice(start_idx, end_idx),)
        kv_cache[index] = new_kv
        return kv_cache

# Module-wide manager instance handed out by get_cache_manager().
_CACHE_MANAGER = CacheManager()


def get_cache_manager():
    """Return the module-level ``CacheManager``, creating it if it is unset."""
    global _CACHE_MANAGER
    manager = _CACHE_MANAGER
    if manager is None:
        # Rebuild the shared instance if the global was reset to None.
        manager = _CACHE_MANAGER = CacheManager()
    return manager
91 changes: 91 additions & 0 deletions src/diffusers/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import torch
from torch import nn
from torch.nn import functional as F
from diffusers.runtime_state import get_runtime_state
from diffusers.parallel_state import get_pipeline_parallel_world_size, get_sequence_parallel_world_size
from loguru import logger

class CustomConv2d(nn.Module):
    """Wrapper around ``nn.Conv2d`` that can convolve one row patch at a time.

    In patch mode the wrapper keeps a full-height copy of its input
    (``activation_cache``), overwrites the rows of the current pipeline patch
    with the fresh input, and convolves only those rows plus a halo of
    neighbouring rows taken from the cache.
    """

    def __init__(
        self, conv2d: nn.Conv2d
    ):
        super().__init__()
        self.module = conv2d
        self.module_type = type(self.module)
        # Full-height input buffer; created lazily in ``forward``.
        self.activation_cache = None

    def naive_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped convolution on ``x`` unchanged."""
        return self.module(x)

    def _convolve_rows(
        self, x: torch.Tensor, row_start: int, row_stop: int
    ) -> torch.Tensor:
        """Convolve rows ``[row_start, row_stop)`` of ``x`` with a padding halo.

        The slice is extended by the module's vertical padding on both sides.
        Where that halo would fall outside ``x``, constant padding is used
        instead. Horizontal padding is always applied. Stride, dilation and
        groups are taken from the wrapped module.

        Assumes the wrapped module uses integer (tuple) padding; string
        padding such as ``"same"`` is not supported here.

        Args:
            x: Input indexed as ``x[:, :, rows, cols]``.
            row_start: First row of the patch (inclusive).
            row_stop: Last row of the patch (exclusive).

        Returns:
            The convolution output for the requested rows.
        """
        conv = self.module
        pad_h, pad_w = conv.padding
        height = x.shape[2]

        top = row_start - pad_h
        bottom = row_stop + pad_h
        pad_top = 0
        pad_bottom = 0
        if top < 0:
            # Rows above the image do not exist; pad them instead.
            pad_top = -top
            top = 0
        if bottom > height:
            # Rows below the image do not exist; pad them instead.
            pad_bottom = bottom - height
            bottom = height

        rows = x[:, :, top:bottom, :]
        # F.pad order for the last two dims: (left, right, top, bottom).
        padded = F.pad(rows, [pad_w, pad_w, pad_top, pad_bottom], mode="constant")
        return F.conv2d(
            padded,
            conv.weight,
            conv.bias,
            stride=conv.stride,
            padding="valid",
            dilation=conv.dilation,
            groups=conv.groups,
        )

    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve only the rows of the current pipeline patch.

        The row range is read from the runtime state
        (``pp_patches_start_idx_local`` at ``pipeline_patch_idx``).
        """
        state = get_runtime_state()
        idx = state.pipeline_patch_idx
        row_bounds = state.pp_patches_start_idx_local
        return self._convolve_rows(x, row_bounds[idx], row_bounds[idx + 1])

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the convolution, patch-wise when patch mode is active.

        Extra positional and keyword arguments are accepted and ignored.
        """
        if (
            (
                get_pipeline_parallel_world_size() == 1
                and get_sequence_parallel_world_size() == 1
            )
            or self.module.kernel_size == (1, 1)
            or self.module.kernel_size == 1
        ):
            # No parallelism, or a 1x1 kernel that needs no neighbouring rows.
            return self.naive_forward(x)

        state = get_runtime_state()
        if not state.patch_mode or state.num_pipeline_patch == 1:
            # Full input available: remember it so later patches can use it.
            self.activation_cache = x
            return self.naive_forward(self.activation_cache)

        row_bounds = state.pp_patches_start_idx_local
        if self.activation_cache is None:
            self.activation_cache = torch.zeros(
                [x.shape[0], x.shape[1], row_bounds[-1], x.shape[3]],
                dtype=x.dtype,
                device=x.device,
            )

        # Refresh the rows of the current patch; other rows keep older values.
        idx = state.pipeline_patch_idx
        self.activation_cache[:, :, row_bounds[idx]:row_bounds[idx + 1], :] = x
        return self.sliced_forward(self.activation_cache)
48 changes: 48 additions & 0 deletions src/diffusers/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
from diffusers.runtime_state import get_runtime_state
from torch import nn
from loguru import logger

class CustomPatchEmbed(nn.Module):
    """Patch embedding wrapper that crops positional embeddings per patch.

    The positional embedding is obtained as if the input were the full
    picture. In patch mode it is then cropped to the token range of the
    current pipeline patch before being added to the projected latent.
    """

    def __init__(
        self, patch_embedding: PatchEmbed,
    ):
        super().__init__()
        self.module = patch_embedding
        # self.module.pos_embed is injected in the from_pretrained step.
        self.module_type = type(self.module)
        self.pos_embed = None
        self.activation_cache = None

    def forward(self, latent):
        """Project ``latent`` and add the matching positional embedding.

        Args:
            latent: Input passed to the wrapped module's ``proj``; its last
                dimension is used as the width for the positional embedding.

        Returns:
            The projected latent plus positional embedding, cast to the
            projected latent's dtype. If the wrapped module provides no
            positional embedding, the projected latent alone.
        """
        state = get_runtime_state()
        # Height comes from the configured image size, not the (possibly
        # partial) input; width is read from the input itself.
        height = state.config.height // state.vae_scale_factor
        width = latent.shape[-1]

        latent = self.module.proj(latent)
        if self.module.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC

        # NOTE(review): upstream PatchEmbed applies an optional layer norm at
        # this point; confirm whether it is intentionally skipped here.

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if getattr(self.module, "pos_embed_max_size", None):
            pos_embed = self.module.cropped_pos_embed(height, width)
        else:
            # Without this branch ``pos_embed`` would be unbound below.
            # NOTE(review): falls back to the module's stored embedding and
            # assumes it already matches the full picture size; confirm.
            pos_embed = getattr(self.module, "pos_embed", None)
            if pos_embed is None:
                return latent

        if state.patch_mode:
            start, end = state.pp_patches_token_start_end_idx_global[
                state.pipeline_patch_idx
            ]
            # Keep only the tokens belonging to the current patch.
            pos_embed = pos_embed[:, start:end, :]

        return (latent + pos_embed).to(latent.dtype)
Loading