clean up
DN6 committed Dec 5, 2025
commit 09a7b0ad25bb49017040dd1061e328a76f133b5b
6 changes: 4 additions & 2 deletions src/diffusers/hooks/group_offloading.py
@@ -326,7 +326,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):

         # Some Autoencoder models use a feature cache that is passed through submodules
         # and modified in place. The `send_to_device` call returns a copy of this feature cache object
-        # which causes issues with inplace updates. Use `exclude_kwargs` to mark these cache features
+        # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
         exclude_kwargs = self.config.exclude_kwargs or []
         if exclude_kwargs:
             moved_kwargs = send_to_device(
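
For context, the pattern this hunk documents can be sketched in isolation. A minimal illustration, assuming a hypothetical `move_kwargs` helper rather than the actual diffusers hook; only `send_to_device` from `accelerate.utils` is the real API here:

from accelerate.utils import send_to_device

def move_kwargs(kwargs, device, exclude_kwargs=()):
    # Kwargs named in `exclude_kwargs` (e.g. an in-place feature cache) are
    # passed through untouched, so callers keep mutating the original object
    # instead of a device copy that would be thrown away.
    excluded = {k: v for k, v in kwargs.items() if k in exclude_kwargs}
    movable = {k: v for k, v in kwargs.items() if k not in exclude_kwargs}
    moved = send_to_device(movable, device)
    moved.update(excluded)  # re-attach the shared, uncopied objects
    return moved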
@@ -633,7 +633,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
     for name, submodule in module.named_children():
         # Check if this is an explicitly defined block module
         if name in block_modules:
-            # track submodule using a prefix
+            # Track submodule using a prefix to avoid filename collisions during disk offload.
+            # Without this, submodules sharing the same model class would be assigned identical
+            # filenames (derived from the class name).
             prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
             submodule_config = replace(config, module_prefix=prefix)

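The prefix composition itself is easy to demonstrate. A minimal sketch, assuming an illustrative `Config` dataclass as a stand-in for the real offloading config:

from dataclasses import dataclass, replace

@dataclass
class Config:  # stand-in for the real offloading config, for illustration only
    module_prefix: str = ""

def child_config(config: Config, name: str) -> Config:
    # Same composition rule as the hunk above: nested prefixes accumulate.
    prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
    return replace(config, module_prefix=prefix)

cfg = child_config(child_config(Config(), "encoder"), "blocks")
print(cfg.module_prefix)  # "encoder.blocks." -> filename stems stay unique per submodule
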
7 changes: 0 additions & 7 deletions tests/testing_utils.py
@@ -1439,13 +1439,10 @@ def get_hashed_filename(group_id: str) -> str:

     block_modules_set = set(block_modules) if block_modules is not None else set()
 
-    # Handle groups of ModuleList and Sequential blocks, and explicitly defined block modules
     modules_with_group_offloading = set()
     unmatched_modules = []
     for name, submodule in module.named_children():
-        # Check if this is an explicitly defined block module
         if name in block_modules_set:
-            # Recursively get expected files for the specified submodule with updated prefix
             new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
             submodule_files = _get_expected_safetensors_files(
                 submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
@@ -1454,7 +1451,6 @@ def get_hashed_filename(group_id: str) -> str:
             modules_with_group_offloading.add(name)
 
         elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-            # Handle ModuleList and Sequential blocks as before
             for i in range(0, len(submodule), num_blocks_per_group):
                 current_modules = submodule[i : i + num_blocks_per_group]
                 if not current_modules:
@@ -1464,11 +1460,8 @@ def get_hashed_filename(group_id: str) -> str:
                 for j in range(i, i + len(current_modules)):
                     modules_with_group_offloading.add(f"{name}.{j}")
         else:
-            # This is an unmatched module
             unmatched_modules.append(submodule)
 
-    # Handle the group for unmatched top-level modules and parameters/buffers
-    # We need to check if there are any parameters/buffers that don't belong to modules with group offloading
     parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
     buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

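The `ModuleList`/`Sequential` branch above chunks children into runs of `num_blocks_per_group`, one offload group (and one safetensors file) per chunk. A small self-contained sketch of that grouping, with illustrative layer counts:

import torch

layers = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(5))
num_blocks_per_group = 2
groups = [
    list(layers[i : i + num_blocks_per_group])
    for i in range(0, len(layers), num_blocks_per_group)
]
print([len(g) for g in groups])  # [2, 2, 1] -> the trailing partial chunk still forms its own group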