refactor: streamline block offloading logic
rycerzes committed Dec 3, 2025
commit e7711435bd5197cdb61e836f32ddca2bd8b5c12d
66 changes: 2 additions & 64 deletions src/diffusers/hooks/group_offloading.py
@@ -606,9 +606,8 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         # Check if this is an explicitly defined block module
         if name in block_modules:
             # Apply block offloading to the specified submodule
-            _apply_block_offloading_to_submodule(
-                submodule, name, config, modules_with_group_offloading, matched_module_groups
-            )
+            _apply_group_offloading_block_level(submodule, config)
+            modules_with_group_offloading.add(name)
         elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
             # Handle ModuleList and Sequential blocks as before
             for i in range(0, len(submodule), config.num_blocks_per_group):
@@ -676,67 +675,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


-def _apply_block_offloading_to_submodule(
-    submodule: torch.nn.Module,
-    name: str,
-    config: GroupOffloadingConfig,
-    modules_with_group_offloading: Set[str],
-    matched_module_groups: List[ModuleGroup],
-) -> None:
-    r"""
-    Apply block offloading to a explicitly defined submodule. This function either:
-    1. Offloads the entire submodule as a single group ( SIMPLE APPROACH)
-    2. Recursively applies block offloading to the submodule
-
-    For now, we use the simple approach - offload the entire submodule as a single group.
-    """
-    # Simple approach: offload the entire submodule as a single group
-    # Since AEs are typically small, this is usually okay
-    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-        # If it's a ModuleList or Sequential, apply the normal block-level logic
-        for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = list(submodule[i : i + config.num_blocks_per_group])
-            if len(current_modules) == 0:
-                continue
-
-            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
-            group = ModuleGroup(
-                modules=current_modules,
-                offload_device=config.offload_device,
-                onload_device=config.onload_device,
-                offload_to_disk_path=config.offload_to_disk_path,
-                offload_leader=current_modules[-1],
-                onload_leader=current_modules[0],
-                non_blocking=config.non_blocking,
-                stream=config.stream,
-                record_stream=config.record_stream,
-                low_cpu_mem_usage=config.low_cpu_mem_usage,
-                onload_self=True,
-                group_id=group_id,
-            )
-            matched_module_groups.append(group)
-            for j in range(i, i + len(current_modules)):
-                modules_with_group_offloading.add(f"{name}.{j}")
-    else:
-        # For other modules, treat the entire submodule as a single group
-        group = ModuleGroup(
-            modules=[submodule],
-            offload_device=config.offload_device,
-            onload_device=config.onload_device,
-            offload_to_disk_path=config.offload_to_disk_path,
-            offload_leader=submodule,
-            onload_leader=submodule,
-            non_blocking=config.non_blocking,
-            stream=config.stream,
-            record_stream=config.record_stream,
-            low_cpu_mem_usage=config.low_cpu_mem_usage,
-            onload_self=True,
-            group_id=name,
-        )
-        matched_module_groups.append(group)
-        modules_with_group_offloading.add(name)
-
-
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
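The refactor drops the dedicated `_apply_block_offloading_to_submodule` helper in favor of recursion: when a child's name appears in `block_modules`, the same block-level routine is applied to that child, so its own `ModuleList`/`Sequential` containers are grouped by `num_blocks_per_group` instead of the child being offloaded as a single opaque group. Below is a minimal, self-contained sketch of that dispatch pattern; `collect_block_groups` and its arguments are illustrative names, not diffusers API.

```python
import torch


def collect_block_groups(module, block_modules, num_blocks_per_group, prefix="", groups=None):
    # Sketch of the dispatch this commit adopts: children whose names are listed in
    # `block_modules` are recursed into with the same block-level rules, while
    # ModuleList/Sequential children are chunked into groups of `num_blocks_per_group`.
    if groups is None:
        groups = []
    for name, submodule in module.named_children():
        qualified = f"{prefix}{name}"
        if name in block_modules:
            collect_block_groups(submodule, block_modules, num_blocks_per_group, f"{qualified}.", groups)
        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            for i in range(0, len(submodule), num_blocks_per_group):
                chunk = list(submodule[i : i + num_blocks_per_group])
                groups.append((f"{qualified}_{i}_{i + len(chunk) - 1}", chunk))
    return groups


# Toy model: "decoder" is an explicitly listed block module, so its inner
# Sequential gets grouped rather than the decoder being treated as one big group.
model = torch.nn.Module()
model.blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))
model.decoder = torch.nn.Module()
model.decoder.layers = torch.nn.Sequential(*(torch.nn.Linear(8, 8) for _ in range(2)))
print([gid for gid, _ in collect_block_groups(model, {"decoder"}, num_blocks_per_group=2)])
# -> ['blocks_0_1', 'blocks_2_3', 'decoder.layers_0_1']
```

In the public API this code path is reached through block-level group offloading (e.g. `apply_group_offloading(..., offload_type="block_level", num_blocks_per_group=...)`); the exact keyword arguments depend on the diffusers version.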