@rycerzes
Created November 21, 2025 09:53
"""
Test script for group offloading with block_level for various models.
Tests both AutoencoderKL (SDXL) and AutoencoderKLWan to verify that
block-level group offloading works correctly with models that have
standalone encoder/decoder layers.
"""
import os
import sys
import traceback
from pathlib import Path
import torch
# Enable faster downloads via hf_transfer (requires `pip install hf_transfer`;
# must be set before huggingface_hub/diffusers are imported, since the flag is
# read at import time)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from diffusers import AutoencoderKL, AutoencoderKLWan, StableDiffusionXLPipeline
# Set cache directory to models/
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(exist_ok=True)
print("\n" + "#" * 80)
print("# Group Offloading Test Script")
print("# Testing block_level offloading with various models")
print("#" * 80)
print(f"\nModels will be cached to: {MODELS_DIR.absolute()}")
if not torch.cuda.is_available():
    print("\nERROR: CUDA is not available. This test requires a CUDA device.")
    sys.exit(1)
print(f"\nCUDA Device: {torch.cuda.get_device_name(0)}")
print(f"PyTorch Version: {torch.__version__}")
onload_device = torch.device("cuda:0")
offload_device = torch.device("cpu")
print(f"\nOnload device: {onload_device}")
print(f"Offload device: {offload_device}")

def test_sdxl_vae_block_level():
    """Test SDXL AutoencoderKL with block-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 1: SDXL VAE with block-level offloading")
    print("=" * 80)
    try:
        print("Loading SDXL VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            cache_dir=MODELS_DIR,
            use_safetensors=True,
        )
        print("VAE loaded")
        print("\nEnabling block-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        print("Group offloading enabled")
        print("\nTesting decode...")
        latents = torch.randn(1, 4, 64, 64, device=onload_device, dtype=torch.bfloat16)
        with torch.no_grad():
            decoded = vae.decode(latents).sample
        print(f"Test passed - output shape: {decoded.shape}")
        return True
    except RuntimeError as e:
        if "Input type" in str(e) and "weight type" in str(e):
            print(f"Test failed - device mismatch: {e}")
            traceback.print_exc()
            return False
        else:
            print(f"Test failed - unexpected error: {e}")
            traceback.print_exc()
            raise
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        raise

def test_sdxl_vae_leaf_level():
    """Test SDXL AutoencoderKL with leaf-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 2: SDXL VAE with leaf-level offloading")
    print("=" * 80)
    try:
        print("Loading SDXL VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            cache_dir=MODELS_DIR,
            use_safetensors=True,
        )
        print("VAE loaded")
        print("\nEnabling leaf-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=False,
        )
        print("Group offloading enabled")
        print("\nTesting decode...")
        latents = torch.randn(1, 4, 64, 64, device=onload_device, dtype=torch.bfloat16)
        with torch.no_grad():
            decoded = vae.decode(latents).sample
        print(f"Test passed - output shape: {decoded.shape}")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        return False

def test_wan_vae_block_level():
    """Test AutoencoderKLWan with block-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 3: WAN VAE with block-level offloading")
    print("=" * 80)
    try:
        print("Loading AutoencoderKLWan...")
        vae = AutoencoderKLWan.from_pretrained(
            "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
            subfolder="vae",
            torch_dtype=torch.float32,
            cache_dir=MODELS_DIR,
        )
        print("VAE loaded")
        print("\nAnalyzing structure...")
        print(f"  - encoder type: {type(vae.encoder).__name__}")
        print(f"  - decoder type: {type(vae.decoder).__name__}")
        children_names = [name for name, _ in vae.named_children()]
        print(f"  - named_children: {children_names}")
        print("\nEnabling block-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="block_level",
            num_blocks_per_group=2,
            use_stream=False,
        )
        print("Group offloading enabled")
        print("\nTesting encode...")
        test_input = torch.randn(1, 3, 1, 64, 64, device=onload_device, dtype=torch.float32)
        with torch.no_grad():
            encoded = vae.encode(test_input)
        print(f"Test passed - output device: {encoded.latent_dist.sample().device}")
        return True
    except Exception as e:
        if "Wan-AI" in str(e) or "not found" in str(e).lower():
            print(f"Model not available: {e}")
            return None
        elif "Input type" in str(e) and "weight type" in str(e):
            print(f"Test failed - device mismatch: {e}")
            traceback.print_exc()
            return False
        else:
            print(f"Test failed: {e}")
            traceback.print_exc()
            raise

def test_wan_vae_leaf_level():
    """Test AutoencoderKLWan with leaf-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 4: WAN VAE with leaf-level offloading")
    print("=" * 80)
    try:
        print("Loading AutoencoderKLWan...")
        vae = AutoencoderKLWan.from_pretrained(
            "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
            subfolder="vae",
            torch_dtype=torch.float32,
            cache_dir=MODELS_DIR,
        )
        print("VAE loaded")
        print("\nEnabling leaf-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=False,
        )
        print("Group offloading enabled")
        print("\nTesting encode...")
        test_input = torch.randn(1, 3, 1, 64, 64, device=onload_device, dtype=torch.float32)
        with torch.no_grad():
            encoded = vae.encode(test_input)
        print(f"Test passed - output device: {encoded.latent_dist.sample().device}")
        return True
    except Exception as e:
        if "Wan-AI" in str(e) or "not found" in str(e).lower():
            print(f"Model not available: {e}")
            return None
        else:
            print(f"Test failed: {e}")
            traceback.print_exc()
            return False

def test_sdxl_pipeline_block_level():
    """Test full SDXL pipeline with block-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 5: SDXL Pipeline with block-level offloading")
    print("=" * 80)
    try:
        print("Loading SDXL pipeline...")
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.bfloat16,
            variant="fp16",
            use_safetensors=True,
            cache_dir=MODELS_DIR,
        )
        print("Pipeline loaded")
        print("\nEnabling block-level group offloading...")
        pipe.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        print("Group offloading enabled")
        print("\nGenerating image...")
        _image = pipe(
            prompt="A beautiful painting of a futuristic cityscape at sunset",
            width=512,
            height=512,
            num_inference_steps=5,
            generator=torch.Generator(device=onload_device).manual_seed(42),
        ).images[0]
        print("Test passed")
        return True
    except RuntimeError as e:
        if "should be the same" in str(e) or "device" in str(e).lower():
            print(f"Test failed - device mismatch: {e}")
            traceback.print_exc()
            return False
        else:
            print(f"Test failed: {e}")
            traceback.print_exc()
            raise
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        raise

def main():
    results = {}
    # Test SDXL VAE
    results["sdxl_vae_block_level"] = test_sdxl_vae_block_level()
    results["sdxl_vae_leaf_level"] = test_sdxl_vae_leaf_level()
    # Test WAN VAE (skipped if the checkpoint is unavailable)
    wan_block_result = test_wan_vae_block_level()
    if wan_block_result is not None:
        results["wan_vae_block_level"] = wan_block_result
    wan_leaf_result = test_wan_vae_leaf_level()
    if wan_leaf_result is not None:
        results["wan_vae_leaf_level"] = wan_leaf_result
    # Test the full pipeline only if requested (set RUN_FULL_TESTS=1)
    if os.environ.get("RUN_FULL_TESTS", "0") == "1":
        results["sdxl_pipeline_block_level"] = test_sdxl_pipeline_block_level()
    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    for test_name, passed in results.items():
        status = "PASSED" if passed else "FAILED"
        print(f"{test_name:35s}: {status}")
    print("=" * 80)
    # Exit code reflects the results
    failed = [name for name, passed in results.items() if not passed]
    if failed:
        print(f"\nFailed tests: {', '.join(failed)}")
        sys.exit(1)
    else:
        print("\nAll tests passed")
        sys.exit(0)


if __name__ == "__main__":
    main()

@rycerzes (Author)

Standalone script output for commit 59b6b678295214b70f6ecaa3f95129b76baf50d8

################################################################################
# Group Offloading Test Script
# Testing block_level offloading with various models
################################################################################

Models will be cached to: D:\Github\oss\diffusers\models

CUDA Device: NVIDIA GeForce RTX 4070 Laptop GPU
PyTorch Version: 2.9.1+cu128

Onload device: cuda:0
Offload device: cpu

================================================================================
TEST 1: SDXL VAE with block-level offloading
================================================================================
Loading SDXL VAE...
VAE loaded

Enabling block-level group offloading...
Group offloading enabled

Testing decode...
Test passed - output shape: torch.Size([1, 3, 512, 512])

================================================================================
TEST 2: SDXL VAE with leaf-level offloading
================================================================================
Loading SDXL VAE...
VAE loaded

Enabling leaf-level group offloading...
Group offloading enabled

Testing decode...
Test passed - output shape: torch.Size([1, 3, 512, 512])

================================================================================
TEST 3: WAN VAE with block-level offloading
================================================================================
Loading AutoencoderKLWan...
VAE loaded

Analyzing structure...
  - encoder type: WanEncoder3d
  - decoder type: WanDecoder3d
  - named_children: ['encoder', 'quant_conv', 'post_quant_conv', 'decoder']

Enabling block-level group offloading...
Group offloading enabled

Testing encode...
Test passed - output device: cuda:0

================================================================================
TEST 4: WAN VAE with leaf-level offloading
================================================================================
Loading AutoencoderKLWan...
VAE loaded

Enabling leaf-level group offloading...
Group offloading enabled

Testing encode...
Test passed - output device: cuda:0

================================================================================
SUMMARY
================================================================================
sdxl_vae_block_level               : PASSED
sdxl_vae_leaf_level                : PASSED
wan_vae_block_level                : PASSED
wan_vae_leaf_level                 : PASSED
================================================================================

All tests passed
