test: for models with standalone and deeply nested layers in block-level offloading
rycerzes committed Nov 21, 2025
commit 59b6b678295214b70f6ecaa3f95129b76baf50d8
298 changes: 298 additions & 0 deletions tests/hooks/test_group_offloading.py
@@ -149,6 +149,146 @@ def post_forward(self, module, output):
return output


# Model simulating VAE structure with standalone computational layers
class DummyVAELikeModel(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()

        # Encoder container (simulates the VAE encoder)
self.encoder = torch.nn.Sequential(
torch.nn.Linear(in_features, hidden_features),
torch.nn.ReLU(),
)

# Standalone Conv2d layer (simulates quant_conv)
self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1)

# Decoder container with nested ModuleList
self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features)

# Standalone Conv2d layer (simulates post_quant_conv)
self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1)

# Output projection
self.linear_out = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encode
x = self.encoder(x)

# Reshape for conv operations
batch_size = x.shape[0]
x_reshaped = x.view(batch_size, 1, -1, 1)

# Apply standalone conv layers
x_reshaped = self.quant_conv(x_reshaped)
x_reshaped = self.post_quant_conv(x_reshaped)

# Reshape back
x = x_reshaped.view(batch_size, -1)

# Decode
x = self.decoder(x)

# Output
x = self.linear_out(x)
return x


class DecoderWithNestedBlocks(torch.nn.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()

        # Standalone computational layer (simulates the decoder's conv_in)
self.conv_in = torch.nn.Linear(in_features, in_features)

# Nested ModuleList (like VAE's decoder.up_blocks)
self.up_blocks = torch.nn.ModuleList(
[torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)]
)

# Non-computational layer
self.norm = torch.nn.LayerNorm(in_features)

self.conv_out = torch.nn.Linear(in_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_in(x)
for block in self.up_blocks:
x = block(x)
x = self.norm(x)
x = self.conv_out(x)
return x


# Model with only standalone computational layers at top level
class DummyModelWithStandaloneLayers(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()

self.layer1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.layer2 = torch.nn.Linear(hidden_features, hidden_features)
self.layer3 = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
x = self.layer3(x)
return x


# Model with deeply nested structure
class DummyModelWithDeeplyNestedBlocks(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()

self.input_layer = torch.nn.Linear(in_features, hidden_features)
self.container = ContainerWithNestedModuleList(hidden_features)
self.output_layer = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.input_layer(x)
x = self.container(x)
x = self.output_layer(x)
return x


class ContainerWithNestedModuleList(torch.nn.Module):
def __init__(self, features: int) -> None:
super().__init__()

# Top-level computational layer
self.proj_in = torch.nn.Linear(features, features)

# Nested container with ModuleList
self.nested_container = NestedContainer(features)

# Another top-level computational layer
self.proj_out = torch.nn.Linear(features, features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj_in(x)
x = self.nested_container(x)
x = self.proj_out(x)
return x


class NestedContainer(torch.nn.Module):
def __init__(self, features: int) -> None:
super().__init__()

self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)])
self.norm = torch.nn.LayerNorm(features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
for block in self.blocks:
x = block(x)
x = self.norm(x)
return x


@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -362,3 +502,161 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)

def test_vae_like_model_with_standalone_conv_layers(self):
"""Test that models with standalone Conv2d layers (like VAE) work with block-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x)
out = model(x)

self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.")

def test_vae_like_model_without_streams(self):
"""Test VAE-like model with block-level offloading but without streams."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x)
out = model(x)

self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
)

def test_model_with_only_standalone_layers(self):
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x)
out = model(x)

self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for model with standalone layers."
)

def test_model_with_deeply_nested_blocks(self):
"""Test models with deeply nested structure where ModuleList is not at top level."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x)
out = model(x)

self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for deeply nested model.")

@parameterized.expand([("block_level",), ("leaf_level",)])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x)
out = model(x)

self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match for standalone Conv layers with {offload_type}.",
)

def test_multiple_invocations_with_vae_like_model(self):
"""Test that multiple forward passes work correctly with VAE-like model."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
for i in range(5):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")

def test_nested_container_parameters_offloading(self):
[Review comment from a Member]: Can we reuse an existing modeling class from the library without using DummyModelWithDeeplyNestedBlocks? Not a merge-blocker, though. (A rough sketch of that idea follows after the diff.)
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
for i in range(3):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for nested parameters.",
)
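
As a follow-up to the review comment above, here is a rough sketch of how the standalone-Conv2d code path could be exercised with an existing modeling class instead of DummyVAELikeModel. This is only an illustration, not part of the PR: the tiny AutoencoderKL config below is an assumption modeled on the small VAE configs used elsewhere in the diffusers test suite, and the values may need adjusting. AutoencoderKL does (by default) expose standalone quant_conv / post_quant_conv Conv2d layers at its top level, which is the structure DummyVAELikeModel simulates.

# Sketch only: reusing a real library class for the standalone-conv test case,
# as suggested in the review comment. Config values are assumptions based on
# the small VAE configs used in other diffusers tests.
import torch

from diffusers import AutoencoderKL
from diffusers.utils.testing_utils import torch_device

vae = AutoencoderKL(
    block_out_channels=[32, 64],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
    latent_channels=4,
    norm_num_groups=32,
)

# Reference copy kept fully on the accelerator for comparison.
vae_ref = AutoencoderKL.from_config(vae.config)
vae_ref.load_state_dict(vae.state_dict(), strict=True)
vae_ref.to(torch_device)

# quant_conv / post_quant_conv are standalone Conv2d layers at the top level of
# AutoencoderKL, mirroring what the dummy model in this PR exercises.
vae.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

sample = torch.randn(1, 3, 32, 32).to(torch_device)
with torch.no_grad():
    out_ref = vae_ref(sample).sample
    out = vae(sample).sample
assert torch.allclose(out_ref, out, atol=1e-5)

Whether this is preferable to the dummy model is a maintainer call; the dummy keeps the test small and independent of the VAE implementation, while the sketch above covers the real class directly.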