""" Test script for group offloading with block_level for various models. Tests both AutoencoderKL (SDXL) and AutoencoderKLWan to verify that block-level group offloading works correctly with models that have standalone encoder/decoder layers. """ import os import sys import traceback from pathlib import Path import torch # Enable faster downloads with hf_transfer if available os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" from diffusers import AutoencoderKL, AutoencoderKLWan, StableDiffusionXLPipeline # Set cache directory to models/ MODELS_DIR = Path("./models") MODELS_DIR.mkdir(exist_ok=True) print("\n" + "#" * 80) print("# Group Offloading Test Script") print("# Testing block_level offloading with various models") print("#" * 80) print(f"\nModels will be cached to: {MODELS_DIR.absolute()}") if not torch.cuda.is_available(): print("\nERROR: CUDA is not available. This test requires a CUDA device.") sys.exit(1) print(f"\nCUDA Device: {torch.cuda.get_device_name(0)}") print(f"PyTorch Version: {torch.__version__}") onload_device = torch.device("cuda:0") offload_device = torch.device("cpu") print(f"\nOnload device: {onload_device}") print(f"Offload device: {offload_device}") def test_sdxl_vae_block_level(): """Test SDXL AutoencoderKL with block-level offloading""" print("\n" + "=" * 80) print("TEST 1: SDXL VAE with block-level offloading") print("=" * 80) try: print("Loading SDXL VAE...") vae = AutoencoderKL.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=MODELS_DIR, use_safetensors=True, ) print("VAE loaded") print("\nEnabling block-level group offloading...") vae.enable_group_offload( onload_device=onload_device, offload_device=offload_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False, ) print("Group offloading enabled") print("\nTesting decode...") latents = torch.randn(1, 4, 64, 64, device=onload_device, dtype=torch.bfloat16) with torch.no_grad(): decoded = vae.decode(latents).sample print(f"Test passed - output shape: {decoded.shape}") return True except RuntimeError as e: if "Input type" in str(e) and "weight type" in str(e): print(f"Test failed - device mismatch: {e}") traceback.print_exc() return False else: print(f"Test failed - unexpected error: {e}") traceback.print_exc() raise except Exception as e: print(f"Test failed: {e}") traceback.print_exc() raise def test_sdxl_vae_leaf_level(): """Test SDXL AutoencoderKL with leaf-level offloading""" print("\n" + "=" * 80) print("TEST 2: SDXL VAE with leaf-level offloading") print("=" * 80) try: print("Loading SDXL VAE...") vae = AutoencoderKL.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=MODELS_DIR, use_safetensors=True, ) print("VAE loaded") print("\nEnabling leaf-level group offloading...") vae.enable_group_offload( onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=False, ) print("Group offloading enabled") print("\nTesting decode...") latents = torch.randn(1, 4, 64, 64, device=onload_device, dtype=torch.bfloat16) with torch.no_grad(): decoded = vae.decode(latents).sample print(f"Test passed - output shape: {decoded.shape}") return True except Exception as e: print(f"Test failed: {e}") traceback.print_exc() return False def test_wan_vae_block_level(): """Test AutoencoderKLWan with block-level offloading""" print("\n" + "=" * 80) print("TEST 3: WAN VAE with block-level offloading") print("=" * 80) try: print("Loading AutoencoderKLWan...") vae = AutoencoderKLWan.from_pretrained( "Wan-AI/Wan2.1-VACE-1.3B-diffusers", subfolder="vae", torch_dtype=torch.float32, cache_dir=MODELS_DIR, ) print("VAE loaded") print("\nAnalyzing structure...") print(f" - encoder type: {type(vae.encoder).__name__}") print(f" - decoder type: {type(vae.decoder).__name__}") children_names = [name for name, _ in vae.named_children()] print(f" - named_children: {children_names}") print("\nEnabling block-level group offloading...") vae.enable_group_offload( onload_device=onload_device, offload_device=offload_device, offload_type="block_level", num_blocks_per_group=2, use_stream=False, ) print("Group offloading enabled") print("\nTesting encode...") test_input = torch.randn(1, 3, 1, 64, 64, device=onload_device, dtype=torch.float32) with torch.no_grad(): encoded = vae.encode(test_input) print(f"Test passed - output device: {encoded.latent_dist.sample().device}") return True except Exception as e: if "Wan-AI" in str(e) or "not found" in str(e).lower(): print(f"Model not available: {e}") return None elif "Input type" in str(e) and "weight type" in str(e): print(f"Test failed - device mismatch: {e}") traceback.print_exc() return False else: print(f"Test failed: {e}") traceback.print_exc() raise def test_wan_vae_leaf_level(): """Test AutoencoderKLWan with leaf-level offloading""" print("\n" + "=" * 80) print("TEST 4: WAN VAE with leaf-level offloading") print("=" * 80) try: print("Loading AutoencoderKLWan...") vae = AutoencoderKLWan.from_pretrained( "Wan-AI/Wan2.1-VACE-1.3B-diffusers", subfolder="vae", torch_dtype=torch.float32, cache_dir=MODELS_DIR, ) print("VAE loaded") print("\nEnabling leaf-level group offloading...") vae.enable_group_offload( onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=False, ) print("Group offloading enabled") print("\nTesting encode...") test_input = torch.randn(1, 3, 1, 64, 64, device=onload_device, dtype=torch.float32) with torch.no_grad(): encoded = vae.encode(test_input) print(f"Test passed - output device: {encoded.latent_dist.sample().device}") return True except Exception as e: if "Wan-AI" in str(e) or "not found" in str(e).lower(): print(f"Model not available: {e}") return None else: print(f"Test failed: {e}") traceback.print_exc() return False def test_sdxl_pipeline_block_level(): """Test full SDXL pipeline with block-level offloading""" print("\n" + "=" * 80) print("TEST 5: SDXL Pipeline with block-level offloading") print("=" * 80) try: print("Loading SDXL pipeline...") pipe = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True, cache_dir=MODELS_DIR, ) print("Pipeline loaded") print("\nEnabling block-level group offloading...") pipe.enable_group_offload( onload_device=onload_device, offload_device=offload_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False, ) print("Group offloading enabled") print("\nGenerating image...") _image = pipe( prompt="A beautiful painting of a futuristic cityscape at sunset", width=512, height=512, num_inference_steps=5, generator=torch.Generator(device=onload_device).manual_seed(42), ).images[0] print("Test passed") return True except RuntimeError as e: if "should be the same" in str(e) or "device" in str(e).lower(): print(f"Test failed - device mismatch: {e}") traceback.print_exc() return False else: print(f"Test failed: {e}") traceback.print_exc() raise except Exception as e: print(f"Test failed: {e}") traceback.print_exc() raise def main(): results = {} # Test SDXL VAE results["sdxl_vae_block_level"] = test_sdxl_vae_block_level() results["sdxl_vae_leaf_level"] = test_sdxl_vae_leaf_level() # Test WAN VAE wan_block_result = test_wan_vae_block_level() if wan_block_result is not None: results["wan_vae_block_level"] = wan_block_result wan_leaf_result = test_wan_vae_leaf_level() if wan_leaf_result is not None: results["wan_vae_leaf_level"] = wan_leaf_result # Test full pipeline if requested if os.environ.get("RUN_FULL_TESTS", "0") == "1": results["sdxl_pipeline_block_level"] = test_sdxl_pipeline_block_level() # Summary print("\n" + "=" * 80) print("SUMMARY") print("=" * 80) for test_name, passed in results.items(): status = "PASSED" if passed else "FAILED" print(f"{test_name:35s}: {status}") print("=" * 80) # Return exit code based on results failed = [name for name, passed in results.items() if not passed] if failed: print(f"\nFailed tests: {', '.join(failed)}") sys.exit(1) else: print("\nAll tests passed") sys.exit(0) if __name__ == "__main__": main()