feat: support for block-level offloading in group offloading config
rycerzes committed Nov 24, 2025
commit fa94f37f441de7494b0fc726644fa67ef358b90d
241 changes: 83 additions & 158 deletions src/diffusers/hooks/group_offloading.py
@@ -59,6 +59,7 @@ class GroupOffloadingConfig:
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None


class ModuleGroup:
@@ -77,7 +78,7 @@ def __init__(
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -453,6 +454,7 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
Expand Down Expand Up @@ -510,6 +512,9 @@ def apply_group_offloading(
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. The listed modules are offloaded at
block level in addition to the automatically detected `torch.nn.ModuleList` and `torch.nn.Sequential`
containers. If not provided, only the default block detection logic is used.

Example:
```python
Expand Down Expand Up @@ -561,6 +566,7 @@ def apply_group_offloading(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
)
_apply_group_offloading(module, config)
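For context, a minimal usage sketch of the new `block_modules` argument (not part of this diff): the model definition and the module names `blocks`, `mid_block`, and `head` are illustrative assumptions, and the call assumes the existing `apply_group_offloading` signature.

```python
import torch
from torch import nn

from diffusers.hooks.group_offloading import apply_group_offloading


class MidBlock(nn.Module):
    # Hypothetical submodule that is neither a ModuleList nor a Sequential.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(self.norm(x))


class TinyModel(nn.Module):
    # Illustrative stand-in model; not taken from this diff.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])
        self.mid_block = MidBlock(64)
        self.head = nn.Linear(64, 8)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.head(self.mid_block(x))


model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    # New in this PR: "mid_block" is not a ModuleList/Sequential, so without block_modules
    # it would fall into the catch-all unmatched-module group. Listing it here gives it
    # its own offload group.
    block_modules=["mid_block"],
)
```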

@@ -576,84 +582,67 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf

def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, as well as to
explicitly listed block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this
offloading is done at the level of top-level blocks and the modules named in block_modules.

Standalone computational layers (Conv2d, Linear, etc.) that are not part of ModuleList/Sequential are treated
individually with leaf-level logic to ensure proper device management. This includes computational layers nested
within container modules.
When block_modules is provided, the listed children are treated as blocks alongside the automatically detected
containers. Each listed child is either split into groups of num_blocks_per_group (if it is a ModuleList or
Sequential) or offloaded as a single group.
"""

if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1

# Create module groups for ModuleList and Sequential blocks
block_modules = set(config.block_modules) if config.block_modules is not None else set()

# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
unmatched_computational_layers = []
matched_module_groups = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Check if this is a computational layer that should be handled individually
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
unmatched_computational_layers.append((name, submodule))
modules_with_group_offloading.add(name)
else:
# This is a container module - recursively find computational layers within it
_find_and_apply_computational_layer_hooks(submodule, name, config, modules_with_group_offloading)
unmatched_modules.append((name, submodule))
# Do NOT add the container name to modules_with_group_offloading here, because we need
# parameters from non-computational sublayers (like GroupNorm) to be gathered
continue

for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
for name, submodule in module.named_children():
# Check if this is an explicitly defined block module
if name in block_modules:
# Apply block offloading to the specified submodule
_apply_block_offloading_to_submodule(
submodule, name, config, modules_with_group_offloading, matched_module_groups
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)

# Apply leaf-level treatment to standalone computational layers at the top level
# Each computational layer gets its own ModuleGroup with hooks registered directly on it
for name, comp_layer in unmatched_computational_layers:
group = ModuleGroup(
modules=[comp_layer],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=comp_layer,
onload_leader=comp_layer,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(comp_layer, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
@@ -662,7 +651,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the remaining unmatched submodules (non-computational containers) of the top-level
# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
@@ -687,129 +676,65 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _find_and_apply_computational_layer_hooks(
container_module: torch.nn.Module,
container_name: str,
def _apply_block_offloading_to_submodule(
submodule: torch.nn.Module,
name: str,
config: GroupOffloadingConfig,
modules_with_group_offloading: Set[str],
matched_module_groups: List[ModuleGroup],
) -> None:
r"""
Recursively finds all computational layers within a container module and applies individual hooks to them.
This ensures that standalone Conv2d, Linear, etc. layers nested inside container modules (like Encoder/Decoder)
get proper device management.
Apply block offloading to an explicitly defined submodule. This function either:
1. Offloads the entire submodule as a single group (simple approach)
2. Recursively applies block offloading to the submodule

For now, we use the simple approach: offload the entire submodule as a single group.
"""
for name, submodule in container_module.named_modules():
if name == "": # Skip the container itself
continue
# Simple approach: offload the entire submodule as a single group
# Since autoencoders (AEs) are typically small, offloading them as one group is usually acceptable
if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# If it's a ModuleList or Sequential, apply the normal block-level logic
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

# Only apply hooks to supported computational layers
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
full_name = f"{container_name}.{name}"
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=[submodule],
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=full_name,
group_id=group_id,
)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(full_name)

# Also handle parameters and buffers at non-leaf levels within the container
# This is similar to what leaf-level offloading does
module_dict = dict(container_module.named_modules())
parameters = []
buffers = []

for name, param in container_module.named_parameters():
# Check if this parameter has a parent that already got a hook
has_parent_with_hook = False
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
full_parent_name = f"{container_name}.{parent_name}"
if full_parent_name in modules_with_group_offloading:
has_parent_with_hook = True
break
atoms.pop()

if not has_parent_with_hook:
parameters.append((name, param))

for name, buffer in container_module.named_buffers():
# Check if this buffer has a parent that already got a hook
has_parent_with_hook = False
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
full_parent_name = f"{container_name}.{parent_name}"
if full_parent_name in modules_with_group_offloading:
has_parent_with_hook = True
break
atoms.pop()

if not has_parent_with_hook:
buffers.append((name, buffer))

# Group parameters and buffers by their immediate parent module and apply hooks
parent_to_parameters = {}
for name, param in parameters:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in module_dict:
if parent_name in parent_to_parameters:
parent_to_parameters[parent_name].append(param)
else:
parent_to_parameters[parent_name] = [param]
break
atoms.pop()

parent_to_buffers = {}
for name, buffer in buffers:
atoms = name.split(".")
while len(atoms) > 0:
parent_name = ".".join(atoms)
if parent_name in module_dict:
if parent_name in parent_to_buffers:
parent_to_buffers[parent_name].append(buffer)
else:
parent_to_buffers[parent_name] = [buffer]
break
atoms.pop()

parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
for name in parent_names:
params = parent_to_parameters.get(name, [])
bufs = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
full_parent_name = f"{container_name}.{name}"

matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# For other modules, treat the entire submodule as a single group
group = ModuleGroup(
modules=[],
modules=[submodule],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=config.offload_to_disk_path,
parameters=params,
buffers=bufs,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=full_parent_name,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, config=config)
modules_with_group_offloading.add(full_parent_name)
matched_module_groups.append(group)
modules_with_group_offloading.add(name)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
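As a reading aid, the sketch below mirrors only the grouping decision made by `_apply_group_offloading_block_level` above (no hooks, streams, or device moves). It is a standalone approximation rather than diffusers code, and the toy module names are assumptions.

```python
from torch import nn


def plan_block_groups(module: nn.Module, num_blocks_per_group: int, block_modules=None):
    """Approximate the (group_id, member names) pairs the block-level pass would create."""
    block_modules = set(block_modules or [])
    groups, unmatched = [], []
    for name, child in module.named_children():
        if name in block_modules and not isinstance(child, (nn.ModuleList, nn.Sequential)):
            # Explicitly listed non-container child: offloaded as a single group.
            groups.append((name, [name]))
        elif isinstance(child, (nn.ModuleList, nn.Sequential)):
            # Containers (listed or not) are split into chunks of num_blocks_per_group.
            for i in range(0, len(child), num_blocks_per_group):
                last = min(i + num_blocks_per_group, len(child)) - 1
                groups.append((f"{name}_{i}_{last}", [f"{name}.{j}" for j in range(i, last + 1)]))
        else:
            # Everything else joins the catch-all group that is handled lazily together
            # with the top-level parameters and buffers.
            unmatched.append(name)
    return groups, unmatched


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
        self.mid_block = nn.Linear(8, 8)  # not a container
        self.norm_out = nn.LayerNorm(8)


print(plan_block_groups(Toy(), num_blocks_per_group=2, block_modules=["mid_block"]))
# ([('blocks_0_1', ['blocks.0', 'blocks.1']), ('blocks_2_2', ['blocks.2']),
#   ('mid_block', ['mid_block'])], ['norm_out'])
```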
5 changes: 5 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -570,6 +570,10 @@ def enable_group_offload(
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)

# Get block modules from the model if available
block_modules = getattr(self, "_group_offload_block_modules", None)

apply_group_offloading(
module=self,
onload_device=onload_device,
@@ -581,6 +585,7 @@ def enable_group_offload(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
)

def set_attention_backend(self, backend: str) -> None:
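Finally, a sketch of the model-side hook-up added in modeling_utils.py: a hypothetical `ModelMixin` subclass opts named children into block-level offloading by declaring `_group_offload_block_modules`, which `enable_group_offload` now forwards as `block_modules`. The class and module names are illustrative, not part of this diff, and real diffusers models would also mix in `ConfigMixin`.

```python
import torch
from torch import nn

from diffusers.models.modeling_utils import ModelMixin


class TinyDiffusersModel(ModelMixin):
    # Picked up via getattr(self, "_group_offload_block_modules", None) in enable_group_offload.
    _group_offload_block_modules = ["mid_block"]
    _supports_group_offloading = True

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(2)])
        self.mid_block = nn.Linear(32, 32)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.mid_block(x)


model = TinyDiffusersModel()
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)
```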