fix for failing tests
DN6 committed Dec 5, 2025
commit cf65ae34493fbe685c0ace548204705eee3a37d9
46 changes: 37 additions & 9 deletions src/diffusers/hooks/group_offloading.py
@@ -15,7 +15,7 @@
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union

@@ -60,6 +60,8 @@ class GroupOffloadingConfig:
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""


class ModuleGroup:
@@ -321,7 +323,20 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

# Some autoencoder models use a feature cache that is passed through submodules
# and modified in place. `send_to_device` returns a copy of this cache object,
# which breaks in-place updates. Use `exclude_kwargs` to keep such entries out of the copy.
exclude_kwargs = self.config.exclude_kwargs or []
if exclude_kwargs:
moved_kwargs = send_to_device(
{k: v for k, v in kwargs.items() if k not in exclude_kwargs},
self.group.onload_device,
non_blocking=self.group.non_blocking,
)
kwargs.update(moved_kwargs)
else:
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

return args, kwargs
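The split above is easier to see with a standalone illustration. This is a minimal sketch (not part of the diff) of why copying a mutable cache kwarg breaks in-place updates, using accelerate's `send_to_device` directly and `"cpu"` as a stand-in target device:

```python
import torch
from accelerate.utils import send_to_device

feat_cache = [torch.zeros(2)]  # shared mutable state, mutated in place by submodules
kwargs = {"x": torch.ones(2), "feat_cache": feat_cache}

# Naive move: send_to_device rebuilds nested containers, so the returned list is
# a new object and in-place updates made by the submodule never reach the caller.
moved = send_to_device(kwargs, "cpu")
assert moved["feat_cache"] is not feat_cache

# With exclude_kwargs, only the remaining entries are moved and the cache keeps
# its identity inside the original kwargs dict.
exclude_kwargs = ["feat_cache"]
moved = send_to_device({k: v for k, v in kwargs.items() if k not in exclude_kwargs}, "cpu")
kwargs.update(moved)
assert kwargs["feat_cache"] is feat_cache
```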

@@ -456,6 +471,7 @@ def apply_group_offloading(
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -516,6 +532,10 @@ def apply_group_offloading(
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. If provided, only these modules will
be considered for block-level offloading. If not provided, the default block detection logic will be used.
exclude_kwargs (`List[str]`, *optional*):
List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state, such as
caching lists, that must keep its object identity across forward passes. If not provided, it will be
inferred from the module's `_group_offload_exclude_kwargs` attribute, if that attribute exists.

Example:
```python
@@ -557,6 +577,12 @@ def apply_group_offloading(

_raise_error_if_accelerate_model_or_sequential_hook_present(module)

if block_modules is None:
block_modules = getattr(module, "_group_offload_block_modules", None)

if exclude_kwargs is None:
exclude_kwargs = getattr(module, "_group_offload_exclude_kwargs", None)

config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
@@ -568,6 +594,7 @@
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)
_apply_group_offloading(module, config)
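For orientation, a hedged usage sketch of the updated entry point; the repository id is a placeholder and the argument values are illustrative, not taken from this PR:

```python
import torch
from diffusers import AutoencoderKLWan
from diffusers.hooks import apply_group_offloading

# placeholder repository id; any Wan VAE checkpoint layout is assumed here
vae = AutoencoderKLWan.from_pretrained("<repo-id>", subfolder="vae", torch_dtype=torch.float16)

# exclude_kwargs can be passed explicitly; if omitted, the getattr fallback above
# picks up the model's _group_offload_exclude_kwargs attribute instead.
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    exclude_kwargs=["feat_cache", "feat_idx"],
)
```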

@@ -606,8 +633,11 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
for name, submodule in module.named_children():
# Check if this is an explicitly defined block module
if name in block_modules:
# Apply block offloading to the specified submodule
_apply_group_offloading_block_level(submodule, config)
# track the submodule's groups with a name prefix so their group ids stay unique
prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
submodule_config = replace(config, module_prefix=prefix)

_apply_group_offloading_block_level(submodule, submodule_config)
modules_with_group_offloading.add(name)

elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
@@ -617,7 +647,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
if len(current_modules) == 0:
continue

group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
@@ -655,9 +685,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
has_unmatched = len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0

if has_unmatched or len(block_modules) > 0:
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
@@ -671,7 +699,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
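The `module_prefix` threading above matters because a nested block module can contain a list with the same attribute name as its parent. A minimal sketch of the group-id formula used here, with hypothetical module names (not the real Wan VAE layout):

```python
# Without the prefix, groups from parent.up_blocks and parent.decoder.up_blocks
# would both be labelled "up_blocks_0_1" and collide (e.g. in disk-offload filenames).
def make_group_id(module_prefix: str, name: str, start: int, end: int) -> str:
    return f"{module_prefix}{name}_{start}_{end}"

print(make_group_id("", "up_blocks", 0, 1))          # up_blocks_0_1
print(make_group_id("decoder.", "up_blocks", 0, 1))   # decoder.up_blocks_0_1
```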
12 changes: 6 additions & 6 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -619,6 +619,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
feat_idx[0] += 1
else:
x = self.conv_out(x)

return x


@@ -961,11 +962,13 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""

_supports_gradient_checkpointing = False
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

# kwargs to ignore when send_to_device moves inputs/outputs between devices
# these are shared mutable states that are modified in-place and
# should not be subjected to copy operations
_group_offload_exclude_kwargs = ["feat_cache", "feat_idx"]

@register_to_config
def __init__(
self,
@@ -1146,9 +1149,6 @@ def _encode(self, x: torch.Tensor):
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
__import__("ipdb").set_trace()
# cache_devices = [i.device.type for i in self._enc_feat_map]
# any((d != "cuda" for d in cache_devices))

enc = self.quant_conv(out)
self.clear_cache()
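The two class attributes above are what the `getattr` fallback in `apply_group_offloading` looks for. A toy sketch of that resolution order (explicit argument first, then the class attribute), using a stand-in module rather than the real VAE:

```python
import torch


class ToyWanVAE(torch.nn.Module):
    # mirrors the hints declared on AutoencoderKLWan above
    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
    _group_offload_exclude_kwargs = ["feat_cache", "feat_idx"]


def resolve_offload_hints(module, block_modules=None, exclude_kwargs=None):
    # same precedence as apply_group_offloading: an explicit argument wins,
    # otherwise fall back to the module's class attribute (or None)
    if block_modules is None:
        block_modules = getattr(module, "_group_offload_block_modules", None)
    if exclude_kwargs is None:
        exclude_kwargs = getattr(module, "_group_offload_exclude_kwargs", None)
    return block_modules, exclude_kwargs


print(resolve_offload_hints(ToyWanVAE()))
print(resolve_offload_hints(ToyWanVAE(), block_modules=["decoder"]))
```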
6 changes: 3 additions & 3 deletions src/diffusers/models/modeling_utils.py
@@ -531,6 +531,8 @@ def enable_group_offload(
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -571,9 +573,6 @@
f"open an issue at https://github.com/huggingface/diffusers/issues."
)

# Get block modules from the model if available
block_modules = getattr(self, "_group_offload_block_modules", None)

apply_group_offloading(
module=self,
onload_device=onload_device,
@@ -586,6 +585,7 @@
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)

def set_attention_backend(self, backend: str) -> None:
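With the hard-coded `getattr` removed here, `enable_group_offload` simply forwards both arguments, so a caller can override the model's own defaults. A hedged sketch (placeholder repository id, illustrative values):

```python
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained("<repo-id>", subfolder="vae", torch_dtype=torch.float16)

vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["decoder"],                  # override: offload only the decoder blocks
    exclude_kwargs=["feat_cache", "feat_idx"],  # same keys the model declares itself
)
```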
1 change: 0 additions & 1 deletion tests/models/test_modeling_common.py
@@ -1735,7 +1735,6 @@ def run_forward(model):
return model(**inputs_dict)[0]

model = self.model_class(**init_dict)

model.to(torch_device)
output_without_group_offloading = run_forward(model)

25 changes: 18 additions & 7 deletions tests/testing_utils.py
@@ -1425,6 +1425,7 @@ def _get_expected_safetensors_files(
offload_type: str,
num_blocks_per_group: Optional[int] = None,
block_modules: Optional[List[str]] = None,
module_prefix: str = "",
) -> Set[str]:
expected_files = set()

@@ -1439,30 +1440,40 @@ def get_hashed_filename(group_id: str) -> str:
block_modules_set = set(block_modules) if block_modules is not None else set()

# Handle groups of ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
for name, submodule in module.named_children():
# Check if this is an explicitly defined block module
if name in block_modules_set:
# Recursively get expected files for the specified submodule
# Recursively get expected files for the specified submodule with updated prefix
new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
submodule_files = _get_expected_safetensors_files(
submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules
submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
)
expected_files.update(submodule_files)
modules_with_group_offloading.add(name)

elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
expected_files.add(get_hashed_filename(group_id))
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append(module)
unmatched_modules.append(submodule)

# Handle the group for unmatched top-level modules and parameters/buffers
# We need to check if there are any parameters/buffers that don't belong to modules with group offloading
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

# Handle the group for unmatched top-level modules and parameters
if len(unmatched_modules) > 0:
expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))

elif offload_type == "leaf_level":
# Handle leaf-level module groups