171 changes: 127 additions & 44 deletions src/diffusers/hooks/group_offloading.py
@@ -59,6 +59,7 @@ class GroupOffloadingConfig:
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None


class ModuleGroup:
@@ -77,7 +78,7 @@ def __init__(
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -453,6 +454,7 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -510,6 +512,9 @@ def apply_group_offloading(
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. Modules listed here are offloaded as
blocks even if they are not ModuleList or Sequential containers; remaining ModuleList and Sequential children
are still grouped by the default block detection logic.

Example:
```python
@@ -561,6 +566,7 @@ def apply_group_offloading(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
)
_apply_group_offloading(module, config)

@@ -576,28 +582,123 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf

def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, as well as to
explicitly defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this
offloading is done at the level of top-level blocks and the modules named in block_modules.

When block_modules is provided, the named submodules are treated as blocks in addition to the usual container
detection: a named ModuleList or Sequential is grouped as usual, while any other named submodule is offloaded as
a single group.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1

# Create module groups for ModuleList and Sequential blocks
block_modules = set(config.block_modules) if config.block_modules is not None else set()

# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []

for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Check if this is an explicitly defined block module
if name in block_modules:
# Apply block offloading to the specified submodule
_apply_block_offloading_to_submodule(
submodule, name, config, modules_with_group_offloading, matched_module_groups
)
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)
continue

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_block_offloading_to_submodule(
submodule: torch.nn.Module,
name: str,
config: GroupOffloadingConfig,
modules_with_group_offloading: Set[str],
matched_module_groups: List[ModuleGroup],
) -> None:
r"""
Apply block offloading to an explicitly defined submodule. If the submodule is a ModuleList or Sequential, the
normal block-level grouping logic is applied to its children; otherwise, the entire submodule is offloaded as a
single group (the simple approach).
"""
# Simple approach: offload the entire submodule as a single group
# Since autoencoder submodules are typically small, this is usually okay
if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# If it's a ModuleList or Sequential, apply the normal block-level logic
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
@@ -616,42 +717,24 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
# For other modules, treat the entire submodule as a single group
group = ModuleGroup(
modules=[submodule],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
matched_module_groups.append(group)
modules_with_group_offloading.add(name)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
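A minimal sketch (not part of this diff) of how the new `block_modules` argument could be passed to `apply_group_offloading` directly. The toy model is purely illustrative, and the import path and keyword names beyond those visible in this diff (e.g. `offload_type`) are assumptions based on the existing API.

```python
import torch

from diffusers.hooks import apply_group_offloading


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # An explicitly named block that is neither a ModuleList nor a Sequential.
        self.encoder = torch.nn.Conv2d(3, 8, 3, padding=1)
        # Regular ModuleList blocks, still grouped by num_blocks_per_group.
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        x = self.encoder(x).mean(dim=(2, 3))
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",   # assumed keyword, matching the existing API
    num_blocks_per_group=2,
    block_modules=["encoder"],    # "encoder" becomes its own offload group
)
```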
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -964,6 +964,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

@register_to_config
def __init__(
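Both autoencoder classes opt in through the same class attribute. A hedged sketch of how a custom model could declare it; the class, config field, and submodule names here are hypothetical, not from the PR.

```python
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class MyTinyAutoencoder(ModelMixin, ConfigMixin):
    # Named submodules that block-level group offloading should treat as blocks,
    # mirroring the attribute added to AutoencoderKL and AutoencoderKLWan above.
    _group_offload_block_modules = ["encoder", "decoder"]

    @register_to_config
    def __init__(self, latent_channels: int = 4):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Conv2d(3, latent_channels, 3, padding=1))
        self.decoder = torch.nn.Sequential(torch.nn.Conv2d(latent_channels, 3, 3, padding=1))

    def forward(self, sample):
        return self.decoder(self.encoder(sample))
```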
5 changes: 5 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -570,6 +570,10 @@ def enable_group_offload(
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)

# Get block modules from the model if available
block_modules = getattr(self, "_group_offload_block_modules", None)

apply_group_offloading(
module=self,
onload_device=onload_device,
@@ -581,6 +585,7 @@ def enable_group_offload(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
)

def set_attention_backend(self, backend: str) -> None:
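With the change above, `enable_group_offload` forwards the model's `_group_offload_block_modules` (when present) as `block_modules`, so existing user code picks up the new grouping without modification. A hedged usage sketch; the checkpoint id and keyword names beyond those shown in this diff are illustrative.

```python
import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# No block_modules argument here: enable_group_offload() reads the class-level
# _group_offload_block_modules list and passes it through to apply_group_offloading.
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",  # assumed keyword, matching the existing API
    num_blocks_per_group=1,
)
```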