Merged
Changes from 1 commit
fix: update group offloading tests to use AutoencoderKL and adjust input dimensions
rycerzes committed Nov 24, 2025
commit e71d91edd8b4784f977a92d92eb001f85da38bce
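
In short, this commit swaps the hand-rolled DummyVAELikeModel for a small real AutoencoderKL. Below is a minimal, self-contained sketch of the resulting test pattern, assembled from the hunks in the diff that follows; the hard-coded "cuda" device string stands in for the test suite's torch_device, and the plain assert stands in for self.assertTrue.

```python
import torch

from diffusers import AutoencoderKL

# Tiny config, mirroring the get_autoencoder_kl_config() helper added at the end of the diff.
config = {
    "block_out_channels": [2, 4],
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ["DownEncoderBlock2D"] * 2,
    "up_block_types": ["UpDecoderBlock2D"] * 2,
    "latent_channels": 4,
    "norm_num_groups": 2,
    "layers_per_block": 1,
}

device = "cuda"  # illustrative; the tests use torch_device and also allow "xpu"

model = AutoencoderKL(**config)
model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(device)

# Group offloading keeps blocks on CPU and onloads them to `device` one group at a time.
model.enable_group_offload(device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

# 4D image input replaces the flat (2, 64) tensor the old dummy model consumed.
x = torch.randn(2, 3, 32, 32).to(device)

with torch.no_grad():
    out_ref = model_ref(x).sample  # AutoencoderKL returns an output object; .sample is the tensor
    out = model(x).sample

assert torch.allclose(out_ref, out, atol=1e-5)
```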
144 changes: 35 additions & 109 deletions tests/hooks/test_group_offloading.py
@@ -19,6 +19,7 @@
 import torch
 from parameterized import parameterized

+from diffusers import AutoencoderKL
 from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -149,78 +150,6 @@ def post_forward(self, module, output):
         return output


-# Model simulating VAE structure with standalone computational layers
-class DummyVAELikeModel(ModelMixin):
-    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
-        super().__init__()
-
-        # Encoder container (not ModuleList/Sequential at top level)
-        self.encoder = torch.nn.Sequential(
-            torch.nn.Linear(in_features, hidden_features),
-            torch.nn.ReLU(),
-        )
-
-        # Standalone Conv2d layer (simulates quant_conv)
-        self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1)
-
-        # Decoder container with nested ModuleList
-        self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features)
-
-        # Standalone Conv2d layer (simulates post_quant_conv)
-        self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1)
-
-        # Output projection
-        self.linear_out = torch.nn.Linear(hidden_features, out_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # Encode
-        x = self.encoder(x)
-
-        # Reshape for conv operations
-        batch_size = x.shape[0]
-        x_reshaped = x.view(batch_size, 1, -1, 1)
-
-        # Apply standalone conv layers
-        x_reshaped = self.quant_conv(x_reshaped)
-        x_reshaped = self.post_quant_conv(x_reshaped)
-
-        # Reshape back
-        x = x_reshaped.view(batch_size, -1)
-
-        # Decode
-        x = self.decoder(x)
-
-        # Output
-        x = self.linear_out(x)
-        return x
-
-
-class DecoderWithNestedBlocks(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int) -> None:
-        super().__init__()
-
-        # Container modules (not ModuleList/Sequential)
-        self.conv_in = torch.nn.Linear(in_features, in_features)
-
-        # Nested ModuleList (like VAE's decoder.up_blocks)
-        self.up_blocks = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)]
-        )
-
-        # Non-computational layer
-        self.norm = torch.nn.LayerNorm(in_features)
-
-        self.conv_out = torch.nn.Linear(in_features, out_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.conv_in(x)
-        for block in self.up_blocks:
-            x = block(x)
-        x = self.norm(x)
-        x = self.conv_out(x)
-        return x
-
-
 # Model with only standalone computational layers at top level
 class DummyModelWithStandaloneLayers(ModelMixin):
     def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
@@ -503,45 +432,25 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
             cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
         )

-    def test_vae_like_model_with_standalone_conv_layers(self):
-        """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading."""
-        if torch.device(torch_device).type not in ["cuda", "xpu"]:
-            return
-
-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
-
-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
-        model_ref.load_state_dict(model.state_dict(), strict=True)
-        model_ref.to(torch_device)
-
-        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
-
-        x = torch.randn(2, 64).to(torch_device)
-
-        with torch.no_grad():
-            out_ref = model_ref(x)
-            out = model(x)
-
-        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.")
-
     def test_vae_like_model_without_streams(self):
         """Test VAE-like model with block-level offloading but without streams."""
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return

-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)

-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        model_ref = AutoencoderKL(**config)
         model_ref.load_state_dict(model.state_dict(), strict=True)
         model_ref.to(torch_device)

         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)

-        x = torch.randn(2, 64).to(torch_device)
+        x = torch.randn(2, 3, 32, 32).to(torch_device)

         with torch.no_grad():
-            out_ref = model_ref(x)
-            out = model(x)
+            out_ref = model_ref(x).sample
+            out = model(x).sample

         self.assertTrue(
             torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
@@ -597,19 +506,20 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str)
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return

-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)

-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        model_ref = AutoencoderKL(**config)
         model_ref.load_state_dict(model.state_dict(), strict=True)
         model_ref.to(torch_device)

         model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)

-        x = torch.randn(2, 64).to(torch_device)
+        x = torch.randn(2, 3, 32, 32).to(torch_device)

         with torch.no_grad():
-            out_ref = model_ref(x)
-            out = model(x)
+            out_ref = model_ref(x).sample
+            out = model(x).sample

         self.assertTrue(
             torch.allclose(out_ref, out, atol=1e-5),
@@ -621,20 +531,21 @@ def test_multiple_invocations_with_vae_like_model(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return

-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)

-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        model_ref = AutoencoderKL(**config)
         model_ref.load_state_dict(model.state_dict(), strict=True)
         model_ref.to(torch_device)

         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

-        x = torch.randn(2, 64).to(torch_device)
+        x = torch.randn(2, 3, 32, 32).to(torch_device)

         with torch.no_grad():
-            for i in range(5):
-                out_ref = model_ref(x)
-                out = model(x)
+            for i in range(2):
+                out_ref = model_ref(x).sample
+                out = model(x).sample
                 self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")

     def test_nested_container_parameters_offloading(self):
Review comment (Member):
Can we reuse an existing modeling class from the library without using DummyModelWithDeeplyNestedBlocks? Not a merge-blocker, though.

@@ -660,3 +571,18 @@ def test_nested_container_parameters_offloading(self):
             torch.allclose(out_ref, out, atol=1e-5),
             f"Outputs do not match at iteration {i} for nested parameters.",
         )
+
+    def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+        block_out_channels = block_out_channels or [2, 4]
+        norm_num_groups = norm_num_groups or 2
+        init_dict = {
+            "block_out_channels": block_out_channels,
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+            "latent_channels": 4,
+            "norm_num_groups": norm_num_groups,
+            "layers_per_block": 1,
+        }
+        return init_dict