Commit cc50469

fix: compile flags for trtllm fmha_v2 (#2175)
## Summary by CodeRabbit

* **Chores**
  * Removed noisy runtime console prints during build/generation.
  * Updated CUDA compiler requirements to target CUDA 12 and added a new compiler flag for compatibility.
* **Bug Fixes**
  * Added an early check that raises a clear error on unsupported GPU devices (SM120a), preventing misruns.
* **Tests**
  * Test now skips automatically when the required SM120a GPU support is not present.
1 parent eec483b commit cc50469

File tree

* .pre-commit-config.yaml
* flashinfer/jit/attention/fmha_v2/generator_utils.py
* flashinfer/jit/attention/modules.py
* flashinfer/prefill.py
* tests/attention/test_fmha_v2_prefill_deepseek.py

5 files changed: +9 −3 lines changed

.pre-commit-config.yaml

File mode changed from 100644 to 100755.

flashinfer/jit/attention/fmha_v2/generator_utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -3711,10 +3711,10 @@ def generate_files(specs_names):
     ]
     if "CUDA_PATH" in os.environ:
         cmd[0] = os.environ["CUDA_PATH"] + "/bin/" + cmd[0]
-    print('Running command "{}" to build "bin/print_traits.exe":'.format(" ".join(cmd)))
+    # print('Running command "{}" to build "bin/print_traits.exe":'.format(" ".join(cmd)))
     process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
     output, error = process.communicate()
-    print('Running "bin/print_traits.exe":')
+    # print('Running "bin/print_traits.exe":')
     process = subprocess.Popen(
         "bin/print_traits.exe", stdin=subprocess.PIPE, stdout=subprocess.PIPE
     )
```

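The commit silences these two build-time prints by commenting them out. A lighter-touch alternative, sketched below, routes the same messages through Python's `logging` module so they stay recoverable at debug verbosity. This is an illustration only, not part of the commit, and `run_print_traits` is a hypothetical helper name.

```python
# Hypothetical sketch (not part of this commit): keep the build messages
# available at debug verbosity instead of commenting them out.
import logging
import subprocess

logger = logging.getLogger(__name__)


def run_print_traits(cmd: list[str]) -> bytes:
    # Log instead of print, so normal builds stay quiet.
    logger.debug('Running command "%s" to build "bin/print_traits.exe"', " ".join(cmd))
    process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    output, _error = process.communicate()

    logger.debug('Running "bin/print_traits.exe"')
    process = subprocess.Popen(
        "bin/print_traits.exe", stdin=subprocess.PIPE, stdout=subprocess.PIPE
    )
    output, _error = process.communicate()
    return output
```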
flashinfer/jit/attention/modules.py

File mode changed from 100644 to 100755.
Lines changed: 2 additions & 1 deletion
```diff
@@ -1901,9 +1901,10 @@ def gen_trtllm_fmha_v2_module() -> JitSpec:
     source_paths = kernel_paths + [binding_source_path]

     nvcc_flags = current_compilation_context.get_nvcc_flags_list(
-        supported_major_versions=[10, 11, 12]
+        supported_major_versions=[12]
     )
     nvcc_flags.append(f"-I{jit_env.FLASHINFER_CSRC_DIR / 'fmha_v2'}")
+    nvcc_flags.append("-Wno-deprecated-gpu-targets")

     return gen_jit_spec(
         uri,
```

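Narrowing `supported_major_versions` to `[12]` restricts which targets the helper emits `-gencode` flags for, and `-Wno-deprecated-gpu-targets` silences NVCC's warning when a targeted GPU architecture is marked deprecated. Below is a minimal sketch of what such a helper might produce; the function body and the major-version-to-architecture mapping are assumptions for illustration, not FlashInfer's actual implementation.

```python
# Hypothetical sketch of a get_nvcc_flags_list-style helper (not FlashInfer's
# actual implementation; the version-to-architecture mapping is an assumption).
def get_nvcc_flags_list(supported_major_versions: list[int]) -> list[str]:
    # Assumed mapping from major version to an SM target name, for illustration.
    arch_names = {10: "100a", 12: "120a"}
    flags = ["-Wno-deprecated-gpu-targets"]
    for major in supported_major_versions:
        arch = arch_names.get(major)
        if arch is not None:
            flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
    return flags


print(get_nvcc_flags_list([12]))
# ['-Wno-deprecated-gpu-targets', '-gencode=arch=compute_120a,code=sm_120a']
```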
flashinfer/prefill.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -3603,6 +3603,8 @@ def fmha_v2_prefill_deepseek(
         If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor.
         If return_lse is False, the output will be a single tensor.
     """
+    if not is_sm120a_supported(query.device):
+        raise ValueError("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.")
     assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
         "currently only support deepseek r1 192 query and 128 value"
     )
```

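The new guard relies on `is_sm120a_supported` from `flashinfer.utils`, whose implementation is not shown in this diff. A minimal sketch of how such a capability check can be written with PyTorch follows; the function name and the assumption that SM120a corresponds to compute capability 12.0 are illustrative, not FlashInfer's actual code.

```python
# Minimal sketch of an SM120a-style capability check. FlashInfer ships its own
# is_sm120a_supported; this only illustrates the general pattern, and the
# (12, 0) compute-capability mapping is an assumption.
import torch


def is_sm120a_supported_sketch(device: torch.device) -> bool:
    if device.type != "cuda" or not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) == (12, 0)
```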
tests/attention/test_fmha_v2_prefill_deepseek.py

File mode changed from 100644 to 100755.
Lines changed: 3 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@

 from flashinfer.prefill import fmha_v2_prefill_deepseek
 from tests.utils_fp8 import to_float8
+from flashinfer.utils import is_sm120a_supported


 def attention_ref(
@@ -56,6 +57,8 @@ def attention_ref(
 def test_fmha_v2_prefill_deepseek(
     batch_size, num_heads, head_dim_qk, head_dim_v, seq_len, qkv_dtype, o_dtype
 ):
+    if not is_sm120a_supported(torch.device("cuda")):
+        pytest.skip("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.")
     torch.manual_seed(42)

     def initialize_tensors(batch_size, num_heads, head_dim_qk, head_dim_v, seq_len):
```

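The test gates itself with an in-body `pytest.skip`. An equivalent decorator-based form, sketched below, would report the skip reason at collection time instead; this is an alternative illustration, not what the commit does, and the smoke-test name is hypothetical.

```python
# Hypothetical alternative (not in this commit): the same gate expressed as a
# reusable skipif marker.
import pytest
import torch

from flashinfer.utils import is_sm120a_supported

requires_sm120a = pytest.mark.skipif(
    not (torch.cuda.is_available() and is_sm120a_supported(torch.device("cuda"))),
    reason="fmha_v2_prefill_deepseek is only supported on SM120 GPUs.",
)


@requires_sm120a
def test_fmha_v2_prefill_deepseek_smoke():
    ...
```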