-
Notifications
You must be signed in to change notification settings - Fork 6.6k
Fix broken group offloading with block_level for models with standalone layers #12692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
DN6
merged 14 commits into
huggingface:main
from
rycerzes:fix/broken-group-offloading-using-block_level
Dec 5, 2025
Merged
Changes from 1 commit
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
ad1fc37
fix: group offloading to support standalone computational layers in b…
rycerzes 59b6b67
test: for models with standalone and deeply nested layers in block-le…
rycerzes fa94f37
feat: support for block-level offloading in group offloading config
rycerzes fb8a741
fix: group offload block modules to AutoencoderKL and AutoencoderKLWan
rycerzes e71d91e
fix: update group offloading tests to use AutoencoderKL and adjust in…
rycerzes 09dd19b
Merge branch 'main' into fix/broken-group-offloading-using-block_level
rycerzes e771143
refactor: streamline block offloading logic
rycerzes ab9b249
Apply style fixes
github-actions[bot] 26bccde
update tests
DN6 f305934
update
DN6 cf65ae3
fix for failing tests
DN6 09a7b0a
clean up
DN6 c888aac
revert to use skip_keys
DN6 4bd3384
clean up
DN6 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix: update group offloading tests to use AutoencoderKL and adjust in…
…put dimensions
- Loading branch information
commit e71d91edd8b4784f977a92d92eb001f85da38bce
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| import torch | ||
| from parameterized import parameterized | ||
|
|
||
| from diffusers import AutoencoderKL | ||
| from diffusers.hooks import HookRegistry, ModelHook | ||
| from diffusers.models import ModelMixin | ||
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline | ||
|
|
@@ -149,78 +150,6 @@ def post_forward(self, module, output): | |
| return output | ||
|
|
||
|
|
||
| # Model simulating VAE structure with standalone computational layers | ||
| class DummyVAELikeModel(ModelMixin): | ||
| def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: | ||
| super().__init__() | ||
|
|
||
| # Encoder container (not ModuleList/Sequential at top level) | ||
| self.encoder = torch.nn.Sequential( | ||
| torch.nn.Linear(in_features, hidden_features), | ||
| torch.nn.ReLU(), | ||
| ) | ||
|
|
||
| # Standalone Conv2d layer (simulates quant_conv) | ||
| self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) | ||
|
|
||
| # Decoder container with nested ModuleList | ||
| self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features) | ||
|
|
||
| # Standalone Conv2d layer (simulates post_quant_conv) | ||
| self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1) | ||
|
|
||
| # Output projection | ||
| self.linear_out = torch.nn.Linear(hidden_features, out_features) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| # Encode | ||
| x = self.encoder(x) | ||
|
|
||
| # Reshape for conv operations | ||
| batch_size = x.shape[0] | ||
| x_reshaped = x.view(batch_size, 1, -1, 1) | ||
|
|
||
| # Apply standalone conv layers | ||
| x_reshaped = self.quant_conv(x_reshaped) | ||
| x_reshaped = self.post_quant_conv(x_reshaped) | ||
|
|
||
| # Reshape back | ||
| x = x_reshaped.view(batch_size, -1) | ||
|
|
||
| # Decode | ||
| x = self.decoder(x) | ||
|
|
||
| # Output | ||
| x = self.linear_out(x) | ||
| return x | ||
|
|
||
|
|
||
| class DecoderWithNestedBlocks(torch.nn.Module): | ||
| def __init__(self, in_features: int, out_features: int) -> None: | ||
| super().__init__() | ||
|
|
||
| # Container modules (not ModuleList/Sequential) | ||
| self.conv_in = torch.nn.Linear(in_features, in_features) | ||
|
|
||
| # Nested ModuleList (like VAE's decoder.up_blocks) | ||
| self.up_blocks = torch.nn.ModuleList( | ||
| [torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)] | ||
| ) | ||
|
|
||
| # Non-computational layer | ||
| self.norm = torch.nn.LayerNorm(in_features) | ||
|
|
||
| self.conv_out = torch.nn.Linear(in_features, out_features) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| x = self.conv_in(x) | ||
| for block in self.up_blocks: | ||
| x = block(x) | ||
| x = self.norm(x) | ||
| x = self.conv_out(x) | ||
| return x | ||
|
|
||
|
|
||
| # Model with only standalone computational layers at top level | ||
| class DummyModelWithStandaloneLayers(ModelMixin): | ||
| def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: | ||
|
|
@@ -503,45 +432,25 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): | |
| cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" | ||
| ) | ||
|
|
||
| def test_vae_like_model_with_standalone_conv_layers(self): | ||
| """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading.""" | ||
| if torch.device(torch_device).type not in ["cuda", "xpu"]: | ||
| return | ||
|
|
||
| model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
|
|
||
| model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| model_ref.load_state_dict(model.state_dict(), strict=True) | ||
| model_ref.to(torch_device) | ||
|
|
||
| model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) | ||
|
|
||
| x = torch.randn(2, 64).to(torch_device) | ||
|
|
||
| with torch.no_grad(): | ||
| out_ref = model_ref(x) | ||
| out = model(x) | ||
|
|
||
| self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.") | ||
|
|
||
| def test_vae_like_model_without_streams(self): | ||
| """Test VAE-like model with block-level offloading but without streams.""" | ||
| if torch.device(torch_device).type not in ["cuda", "xpu"]: | ||
| return | ||
|
|
||
| model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| config = self.get_autoencoder_kl_config() | ||
| model = AutoencoderKL(**config) | ||
|
|
||
| model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| model_ref = AutoencoderKL(**config) | ||
| model_ref.load_state_dict(model.state_dict(), strict=True) | ||
| model_ref.to(torch_device) | ||
|
|
||
| model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False) | ||
|
|
||
| x = torch.randn(2, 64).to(torch_device) | ||
| x = torch.randn(2, 3, 32, 32).to(torch_device) | ||
|
|
||
| with torch.no_grad(): | ||
| out_ref = model_ref(x) | ||
| out = model(x) | ||
| out_ref = model_ref(x).sample | ||
| out = model(x).sample | ||
|
|
||
| self.assertTrue( | ||
| torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams." | ||
|
|
@@ -597,19 +506,20 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str) | |
| if torch.device(torch_device).type not in ["cuda", "xpu"]: | ||
| return | ||
|
|
||
| model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| config = self.get_autoencoder_kl_config() | ||
| model = AutoencoderKL(**config) | ||
|
|
||
| model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| model_ref = AutoencoderKL(**config) | ||
| model_ref.load_state_dict(model.state_dict(), strict=True) | ||
| model_ref.to(torch_device) | ||
|
|
||
| model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True) | ||
|
|
||
| x = torch.randn(2, 64).to(torch_device) | ||
| x = torch.randn(2, 3, 32, 32).to(torch_device) | ||
|
|
||
| with torch.no_grad(): | ||
| out_ref = model_ref(x) | ||
| out = model(x) | ||
| out_ref = model_ref(x).sample | ||
| out = model(x).sample | ||
|
|
||
| self.assertTrue( | ||
| torch.allclose(out_ref, out, atol=1e-5), | ||
|
|
@@ -621,20 +531,21 @@ def test_multiple_invocations_with_vae_like_model(self): | |
| if torch.device(torch_device).type not in ["cuda", "xpu"]: | ||
| return | ||
|
|
||
| model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| config = self.get_autoencoder_kl_config() | ||
| model = AutoencoderKL(**config) | ||
|
|
||
| model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64) | ||
| model_ref = AutoencoderKL(**config) | ||
| model_ref.load_state_dict(model.state_dict(), strict=True) | ||
| model_ref.to(torch_device) | ||
|
|
||
| model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) | ||
|
|
||
| x = torch.randn(2, 64).to(torch_device) | ||
| x = torch.randn(2, 3, 32, 32).to(torch_device) | ||
|
|
||
| with torch.no_grad(): | ||
| for i in range(5): | ||
| out_ref = model_ref(x) | ||
| out = model(x) | ||
| for i in range(2): | ||
| out_ref = model_ref(x).sample | ||
| out = model(x).sample | ||
| self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") | ||
|
|
||
| def test_nested_container_parameters_offloading(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we reuse an existing modeling class from the library without using |
||
|
|
@@ -660,3 +571,18 @@ def test_nested_container_parameters_offloading(self): | |
| torch.allclose(out_ref, out, atol=1e-5), | ||
| f"Outputs do not match at iteration {i} for nested parameters.", | ||
| ) | ||
|
|
||
| def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None): | ||
| block_out_channels = block_out_channels or [2, 4] | ||
| norm_num_groups = norm_num_groups or 2 | ||
| init_dict = { | ||
| "block_out_channels": block_out_channels, | ||
| "in_channels": 3, | ||
| "out_channels": 3, | ||
| "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), | ||
| "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels), | ||
| "latent_channels": 4, | ||
| "norm_num_groups": norm_num_groups, | ||
| "layers_per_block": 1, | ||
| } | ||
| return init_dict | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.