Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'main' into fix/broken-group-offloading-using-block_level
fix: reduce repeats in group offload tests
  • Loading branch information
rycerzes committed Nov 24, 2025
commit 09dd19bb7d9776ebc314e8c0a54e8a882ba8e7be
38 changes: 9 additions & 29 deletions tests/hooks/test_group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")

-num_repeats = 4
+num_repeats = 2
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
Expand Down Expand Up @@ -472,33 +472,13 @@ def test_model_with_only_standalone_layers(self):
x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
-out_ref = model_ref(x)
-out = model(x)
-
-self.assertTrue(
-torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for model with standalone layers."
-)
-
-def test_model_with_deeply_nested_blocks(self):
-"""Test models with deeply nested structure where ModuleList is not at top level."""
-if torch.device(torch_device).type not in ["cuda", "xpu"]:
-return
-
-model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
-
-model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
-model_ref.load_state_dict(model.state_dict(), strict=True)
-model_ref.to(torch_device)
-
-model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
-
-x = torch.randn(2, 64).to(torch_device)
-
-with torch.no_grad():
-out_ref = model_ref(x)
-out = model(x)
-
-self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for deeply nested model.")
+for i in range(2):
+out_ref = model_ref(x)
+out = model(x)
+self.assertTrue(
+torch.allclose(out_ref, out, atol=1e-5),
+f"Outputs do not match at iteration {i} for model with standalone layers.",
+)

@parameterized.expand([("block_level",), ("leaf_level",)])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
Expand Down Expand Up @@ -564,7 +544,7 @@ def test_nested_container_parameters_offloading(self):
x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
-for i in range(3):
+for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.