Created November 21, 2025 09:53
| """ | |
| Test script for group offloading with block_level for various models. | |
| Tests both AutoencoderKL (SDXL) and AutoencoderKLWan to verify that | |
| block-level group offloading works correctly with models that have | |
| standalone encoder/decoder layers. | |
| """ | |
| import os | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
| import torch | |
| # Enable faster downloads with hf_transfer if available | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
| from diffusers import AutoencoderKL, AutoencoderKLWan, StableDiffusionXLPipeline | |
| # Set cache directory to models/ | |
| MODELS_DIR = Path("./models") | |
| MODELS_DIR.mkdir(exist_ok=True) | |
| print("\n" + "#" * 80) | |
| print("# Group Offloading Test Script") | |
| print("# Testing block_level offloading with various models") | |
| print("#" * 80) | |
| print(f"\nModels will be cached to: {MODELS_DIR.absolute()}") | |
| if not torch.cuda.is_available(): | |
| print("\nERROR: CUDA is not available. This test requires a CUDA device.") | |
| sys.exit(1) | |
| print(f"\nCUDA Device: {torch.cuda.get_device_name(0)}") | |
| print(f"PyTorch Version: {torch.__version__}") | |
| onload_device = torch.device("cuda:0") | |
| offload_device = torch.device("cpu") | |
| print(f"\nOnload device: {onload_device}") | |
| print(f"Offload device: {offload_device}") | |
def test_sdxl_vae_block_level():
    """Test SDXL AutoencoderKL with block-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 1: SDXL VAE with block-level offloading")
    print("=" * 80)
    try:
        print("Loading SDXL VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            cache_dir=MODELS_DIR,
            use_safetensors=True,
        )
        print("VAE loaded")

        print("\nEnabling block-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        print("Group offloading enabled")

        print("\nTesting decode...")
        latents = torch.randn(1, 4, 64, 64, device=onload_device, dtype=torch.bfloat16)
        with torch.no_grad():
            decoded = vae.decode(latents).sample
        print(f"Test passed - output shape: {decoded.shape}")
        return True
    except RuntimeError as e:
        if "Input type" in str(e) and "weight type" in str(e):
            print(f"Test failed - device mismatch: {e}")
            traceback.print_exc()
            return False
        else:
            print(f"Test failed - unexpected error: {e}")
            traceback.print_exc()
            raise
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        raise
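
# A minimal sketch of the equivalent functional API: recent diffusers
# releases also expose group offloading as a standalone helper, so the
# hooks attached by vae.enable_group_offload(...) above could instead be
# applied like this (assuming diffusers.hooks.apply_group_offloading is
# available in the installed version):
#
#     from diffusers.hooks import apply_group_offloading
#     apply_group_offloading(
#         vae,
#         onload_device=onload_device,
#         offload_device=offload_device,
#         offload_type="block_level",
#         num_blocks_per_group=1,
#     )
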
def test_sdxl_vae_leaf_level():
    """Test SDXL AutoencoderKL with leaf-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 2: SDXL VAE with leaf-level offloading")
    print("=" * 80)
    try:
        print("Loading SDXL VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            cache_dir=MODELS_DIR,
            use_safetensors=True,
        )
        print("VAE loaded")

        print("\nEnabling leaf-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=False,
        )
        print("Group offloading enabled")

        print("\nTesting decode...")
        latents = torch.randn(1, 4, 64, 64, device=onload_device, dtype=torch.bfloat16)
        with torch.no_grad():
            decoded = vae.decode(latents).sample
        print(f"Test passed - output shape: {decoded.shape}")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        return False
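
# Both SDXL VAE tests run with use_stream=False (synchronous copies). Per
# the diffusers docs, use_stream=True prefetches the next group on a
# separate CUDA stream so transfers overlap with compute, trading extra
# (pinned) host memory for lower offloading latency.
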
def test_wan_vae_block_level():
    """Test AutoencoderKLWan with block-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 3: WAN VAE with block-level offloading")
    print("=" * 80)
    try:
        print("Loading AutoencoderKLWan...")
        vae = AutoencoderKLWan.from_pretrained(
            "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
            subfolder="vae",
            torch_dtype=torch.float32,
            cache_dir=MODELS_DIR,
        )
        print("VAE loaded")

        print("\nAnalyzing structure...")
        print(f" - encoder type: {type(vae.encoder).__name__}")
        print(f" - decoder type: {type(vae.decoder).__name__}")
        children_names = [name for name, _ in vae.named_children()]
        print(f" - named_children: {children_names}")

        print("\nEnabling block-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="block_level",
            num_blocks_per_group=2,
            use_stream=False,
        )
        print("Group offloading enabled")

        print("\nTesting encode...")
        test_input = torch.randn(1, 3, 1, 64, 64, device=onload_device, dtype=torch.float32)
        with torch.no_grad():
            encoded = vae.encode(test_input)
        print(f"Test passed - output device: {encoded.latent_dist.sample().device}")
        return True
    except Exception as e:
        if "Wan-AI" in str(e) or "not found" in str(e).lower():
            print(f"Model not available: {e}")
            return None
        elif "Input type" in str(e) and "weight type" in str(e):
            print(f"Test failed - device mismatch: {e}")
            traceback.print_exc()
            return False
        else:
            print(f"Test failed: {e}")
            traceback.print_exc()
            raise
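
# AutoencoderKLWan is the interesting case for block_level offloading: its
# encoder/decoder are standalone top-level modules rather than
# ModuleList/Sequential containers, so the block matcher presumably has to
# fall back to grouping the unmatched submodules, which is the scenario
# this script was written to verify.
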
def test_wan_vae_leaf_level():
    """Test AutoencoderKLWan with leaf-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 4: WAN VAE with leaf-level offloading")
    print("=" * 80)
    try:
        print("Loading AutoencoderKLWan...")
        vae = AutoencoderKLWan.from_pretrained(
            "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
            subfolder="vae",
            torch_dtype=torch.float32,
            cache_dir=MODELS_DIR,
        )
        print("VAE loaded")

        print("\nEnabling leaf-level group offloading...")
        vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=False,
        )
        print("Group offloading enabled")

        print("\nTesting encode...")
        test_input = torch.randn(1, 3, 1, 64, 64, device=onload_device, dtype=torch.float32)
        with torch.no_grad():
            encoded = vae.encode(test_input)
        print(f"Test passed - output device: {encoded.latent_dist.sample().device}")
        return True
    except Exception as e:
        if "Wan-AI" in str(e) or "not found" in str(e).lower():
            print(f"Model not available: {e}")
            return None
        else:
            print(f"Test failed: {e}")
            traceback.print_exc()
            return False
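
# In the pipeline test below, enable_group_offload is called on the
# pipeline object itself; in recent diffusers versions this is expected to
# forward the same configuration to every torch.nn.Module component
# (unet, both text encoders, vae), so each component keeps only its active
# group on the GPU.
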
def test_sdxl_pipeline_block_level():
    """Test full SDXL pipeline with block-level offloading"""
    print("\n" + "=" * 80)
    print("TEST 5: SDXL Pipeline with block-level offloading")
    print("=" * 80)
    try:
        print("Loading SDXL pipeline...")
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.bfloat16,
            variant="fp16",
            use_safetensors=True,
            cache_dir=MODELS_DIR,
        )
        print("Pipeline loaded")

        print("\nEnabling block-level group offloading...")
        pipe.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        print("Group offloading enabled")

        print("\nGenerating image...")
        _image = pipe(
            prompt="A beautiful painting of a futuristic cityscape at sunset",
            width=512,
            height=512,
            num_inference_steps=5,
            generator=torch.Generator(device=onload_device).manual_seed(42),
        ).images[0]
        print("Test passed")
        return True
    except RuntimeError as e:
        if "should be the same" in str(e) or "device" in str(e).lower():
            print(f"Test failed - device mismatch: {e}")
            traceback.print_exc()
            return False
        else:
            print(f"Test failed: {e}")
            traceback.print_exc()
            raise
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        raise
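
# The full pipeline test downloads the complete SDXL checkpoint, so it is
# opt-in via an environment variable. Example invocation (hypothetical
# file name):
#
#     RUN_FULL_TESTS=1 python test_group_offloading.py
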
def main():
    results = {}

    # Test SDXL VAE
    results["sdxl_vae_block_level"] = test_sdxl_vae_block_level()
    results["sdxl_vae_leaf_level"] = test_sdxl_vae_leaf_level()

    # Test WAN VAE
    wan_block_result = test_wan_vae_block_level()
    if wan_block_result is not None:
        results["wan_vae_block_level"] = wan_block_result
    wan_leaf_result = test_wan_vae_leaf_level()
    if wan_leaf_result is not None:
        results["wan_vae_leaf_level"] = wan_leaf_result

    # Test full pipeline if requested
    if os.environ.get("RUN_FULL_TESTS", "0") == "1":
        results["sdxl_pipeline_block_level"] = test_sdxl_pipeline_block_level()

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    for test_name, passed in results.items():
        status = "PASSED" if passed else "FAILED"
        print(f"{test_name:35s}: {status}")
    print("=" * 80)

    # Return exit code based on results
    failed = [name for name, passed in results.items() if not passed]
    if failed:
        print(f"\nFailed tests: {', '.join(failed)}")
        sys.exit(1)
    else:
        print("\nAll tests passed")
        sys.exit(0)


if __name__ == "__main__":
    main()
Standalone script output for commit 59b6b678295214b70f6ecaa3f95129b76baf50d8