Skip to content

[SPIRV] Fix getNumScalarOrVectorTotalBitWidth to handle OpTypeBool type#186296

Open
woruyu wants to merge 1 commit intollvm:mainfrom
woruyu:fix/isBitcastCompatible
Open

[SPIRV] Fix getNumScalarOrVectorTotalBitWidth to handle OpTypeBool type#186296
woruyu wants to merge 1 commit intollvm:mainfrom
woruyu:fix/isBitcastCompatible

Conversation

@woruyu
Copy link
Copy Markdown
Member

@woruyu woruyu commented Mar 13, 2026

Summary

This PR resolves #185815.

Problem

getNumScalarOrVectorTotalBitWidth previously only handled OpTypeInt and OpTypeFloat. For boolean scalar/vector types it returned 0, which could break SPIR-V codegen for IR using boolean masks, such as llvm.masked.store.

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Mar 13, 2026

@llvm/pr-subscribers-backend-spir-v

Author: woruyu (woruyu)

Changes

Summary

This PR resolves #185815.

Problem

getNumScalarOrVectorTotalBitWidth previously only handled OpTypeInt and OpTypeFloat. For boolean scalar/vector types it returned 0, which could break SPIR-V codegen for IR using boolean masks, such as llvm.masked.store.


Full diff: https://github.com/llvm/llvm-project/pull/186296.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+3-2)
  • (added) llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll (+91)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9a85634c82626..5eb8928761f68 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1428,8 +1428,9 @@ unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
     Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
   }
   return Type->getOpcode() == SPIRV::OpTypeInt ||
-                 Type->getOpcode() == SPIRV::OpTypeFloat
-             ? NumElements * Type->getOperand(1).getImm()
+                 Type->getOpcode() == SPIRV::OpTypeFloat ||
+                 Type->getOpcode() == SPIRV::OpTypeBool
+             ? NumElements * getScalarOrVectorBitWidth(Type)
              : 0;
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll b/llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll
new file mode 100644
index 0000000000000..8ff60ba68af60
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll
@@ -0,0 +1,91 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+define void @kernelAddConstant(<8 x i1> %0) {
+; CHECK-LABEL: kernelAddConstant
+; CHECK:       %44 = OpFunction %20 None %21 ; -- Begin function kernelAddConstant
+; CHECK-NEXT:    %45 = OpFunctionParameter %19
+; CHECK-NEXT:    %78 = OpLabel
+; CHECK-NEXT:    %46 = OpBitcast %24 %45
+; CHECK-NEXT:    %47 = OpBitwiseAnd %24 %46 %43
+; CHECK-NEXT:    %48 = OpINotEqual %18 %47 %42
+; CHECK-NEXT:    %49 = OpBitcast %23 %41
+; CHECK-NEXT:    %50 = OpInBoundsPtrAccessChain %23 %49 %40
+; CHECK-NEXT:    %51 = OpBitcast %23 %41
+; CHECK-NEXT:    %52 = OpInBoundsPtrAccessChain %23 %51 %39
+; CHECK-NEXT:    %53 = OpBitcast %23 %41
+; CHECK-NEXT:    %54 = OpInBoundsPtrAccessChain %23 %53 %38
+; CHECK-NEXT:    %55 = OpBitcast %23 %41
+; CHECK-NEXT:    %56 = OpInBoundsPtrAccessChain %23 %55 %37
+; CHECK-NEXT:    %57 = OpBitcast %23 %41
+; CHECK-NEXT:    %58 = OpInBoundsPtrAccessChain %23 %57 %36
+; CHECK-NEXT:    %59 = OpBitcast %23 %41
+; CHECK-NEXT:    %60 = OpInBoundsPtrAccessChain %23 %59 %35
+; CHECK-NEXT:    %61 = OpBitcast %23 %41
+; CHECK-NEXT:    %62 = OpInBoundsPtrAccessChain %23 %61 %34
+; CHECK-NEXT:    OpBranchConditional %48 %2 %3
+; CHECK-NEXT:    %2 = OpLabel
+; CHECK-NEXT:    %63 = OpBitcast %23 %41
+; CHECK-NEXT:    OpStore %63 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %3
+; CHECK-NEXT:    %3 = OpLabel
+; CHECK-NEXT:    %64 = OpBitwiseAnd %24 %46 %32
+; CHECK-NEXT:    %65 = OpINotEqual %18 %64 %42
+; CHECK-NEXT:    OpBranchConditional %65 %4 %5
+; CHECK-NEXT:    %4 = OpLabel
+; CHECK-NEXT:    OpStore %50 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %5
+; CHECK-NEXT:    %5 = OpLabel
+; CHECK-NEXT:    %66 = OpBitwiseAnd %24 %46 %31
+; CHECK-NEXT:    %67 = OpINotEqual %18 %66 %42
+; CHECK-NEXT:    OpBranchConditional %67 %6 %7
+; CHECK-NEXT:    %6 = OpLabel
+; CHECK-NEXT:    OpStore %52 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %7
+; CHECK-NEXT:    %7 = OpLabel
+; CHECK-NEXT:    %68 = OpBitwiseAnd %24 %46 %30
+; CHECK-NEXT:    %69 = OpINotEqual %18 %68 %42
+; CHECK-NEXT:    OpBranchConditional %69 %8 %9
+; CHECK-NEXT:    %8 = OpLabel
+; CHECK-NEXT:    OpStore %54 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %9
+; CHECK-NEXT:    %9 = OpLabel
+; CHECK-NEXT:    %70 = OpBitwiseAnd %24 %46 %29
+; CHECK-NEXT:    %71 = OpINotEqual %18 %70 %42
+; CHECK-NEXT:    OpBranchConditional %71 %10 %11
+; CHECK-NEXT:    %10 = OpLabel
+; CHECK-NEXT:    OpStore %56 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %11
+; CHECK-NEXT:    %11 = OpLabel
+; CHECK-NEXT:    %72 = OpBitwiseAnd %24 %46 %28
+; CHECK-NEXT:    %73 = OpINotEqual %18 %72 %42
+; CHECK-NEXT:    OpBranchConditional %73 %12 %13
+; CHECK-NEXT:    %12 = OpLabel
+; CHECK-NEXT:    OpStore %58 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %13
+; CHECK-NEXT:    %13 = OpLabel
+; CHECK-NEXT:    %74 = OpBitwiseAnd %24 %46 %27
+; CHECK-NEXT:    %75 = OpINotEqual %18 %74 %42
+; CHECK-NEXT:    OpBranchConditional %75 %14 %15
+; CHECK-NEXT:    %14 = OpLabel
+; CHECK-NEXT:    OpStore %60 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %15
+; CHECK-NEXT:    %15 = OpLabel
+; CHECK-NEXT:    %76 = OpBitwiseAnd %24 %46 %26
+; CHECK-NEXT:    %77 = OpINotEqual %18 %76 %42
+; CHECK-NEXT:    OpBranchConditional %77 %16 %17
+; CHECK-NEXT:    %16 = OpLabel
+; CHECK-NEXT:    OpStore %62 %33 Aligned 1
+; CHECK-NEXT:    OpBranch %17
+; CHECK-NEXT:    %17 = OpLabel
+; CHECK-NEXT:    OpReturn
+; CHECK-NEXT:    OpFunctionEnd
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> zeroinitializer, ptr addrspace(1) null, <8 x i1> %0)
+  ret void
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: write)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) #0
+
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: write) }

Copy link
Copy Markdown
Contributor

@jmmartinez jmmartinez left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patch looks reasonable to me. I left some rather minor remarks.

However It seems the CI tests failed:

Failed Tests (2):
  LLVM :: CodeGen/SPIRV/masked-store-bool-mask.ll
  LLVM :: CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll

I'd like you to check where the spirv-val fails are coming from before we land this.

Comment on lines +88 to +91
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: write)
declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) #0

attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: write) }
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to put the attributes, the compiler can deduce them. Also, the ; Function Attrs is a comment so we can safely remove it.

Suggested change
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: write)
declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) #0
attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: write) }
declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>)

@@ -0,0 +1,91 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -0,0 +1,91 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I ran this through spirv-val I've got:

error: line 56: Expected input to be a pointer or int or float vector or scalar: Bitcast
  %46 = OpBitcast %uchar %45

Have you looked into it?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that the generated SPIR-V contains:

%18 = OpTypeBool
%19 = OpTypeVector %18 8 
%24 = OpTypeInt 8 0 
%45 = OpFunctionParameter %19
%46 = OpBitcast %24 %45 

and spirv-val rejects it because OpBitcast does not accept OpTypeBool or vector<bool> operands.

; CHECK-NEXT: %17 = OpLabel
; CHECK-NEXT: OpReturn
; CHECK-NEXT: OpFunctionEnd
call void @llvm.masked.store.v8i32.p1(<8 x i32> zeroinitializer, ptr addrspace(1) null, <8 x i1> %0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use a vector <8 x i32> <i32 1, i32 2, i32 3, ...> instead of zeroinitializer ?

When I was reading the generated code I saw the conditional store 0 to the right address. It'd help to distinguish each of the generated stores for each elements of the vector. Otherwise they all look the same.

@jmmartinez
Copy link
Copy Markdown
Contributor

jmmartinez commented Mar 13, 2026

The patch looks reasonable to me. I left some rather minor remarks.

However It seems the CI tests failed:

Failed Tests (2):
  LLVM :: CodeGen/SPIRV/masked-store-bool-mask.ll
  LLVM :: CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll

I'd like you to check where the spirv-val fails are coming from before we land this.

CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll one seems to be due to a spirv-val update unrelated to your PR (this one: KhronosGroup/SPIRV-Tools#6585). You can ignore it. I filed an issue in #186344.

Type->getOpcode() == SPIRV::OpTypeFloat
? NumElements * Type->getOperand(1).getImm()
Type->getOpcode() == SPIRV::OpTypeFloat ||
Type->getOpcode() == SPIRV::OpTypeBool
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpTypeBool in SPIR-V is not the same as i1 from LLVM IR (so frankly, the fact that we are mapping them anyway is a bit speculative, but just works in most of the use cases). But since it has the following line in the definition:
There is no physical size or bit pattern defined for these values.
I feel uneasy to add it to getNumScalarOrVectorTotalBitWidth.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review.
I’m a little confused about the best way to address this. getNumScalarOrVectorTotalBitWidth() is only used by isBitcastCompatible(), where we need some notion of total bit width in order to check whether Bits1 == Bits2.
Since OpTypeBool does not have a defined physical size in SPIR-V, would you prefer that I handle it directly in isBitcastCompatible() instead of extending this helper?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually investigating a similar issue at the moment. The solution I came up with is to pre-process bool vector bitcasts by decomposing such vectors:
bitcast to iN => extract each bool element, zext to integer, shift left, OR together
bitcast iN to => AND with each bit mask, icmp ne 0, insert into result vector

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other option is to resolve masked.load/store before scalarization, avoiding invalid bitcast insertion all along. While I have a patch for vector decomposition - I'm still bargaining with myself whether one or another approach should be productized :D

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've convinced myself, that element-wise decomposition is fine, created #187960

@woruyu
Copy link
Copy Markdown
Member Author

woruyu commented Mar 16, 2026

I have a question about the expected handling here.
This path currently goes through
masked.store -> scalarize -> bitcast <8 x i1> to i8 -> llvm.spv.bitcast.i8.v8i1,
but spirv-val does not accept SPIR-V bitcast on bool vectors, and selectIntrinsic also requires isBitcastCompatible.

My current understanding is that this should probably not be fixed by relaxing selectIntrinsic. Instead, <-> iN likely needs to be lowered as pack/unpack before SPIR-V bitcast emission.

Would the right fix be to stop creating llvm.spv.bitcast for bool-vector cases in SPIRVEmitIntrinsics, or should this be expanded in legalizer instead?

Copy link
Copy Markdown
Contributor

@s-perron s-perron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Bitcast on bool is not valid spir-v. We should not be trying to generate that code. Returning 0 for bool seems appropriate.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MLIR] Crash in SPIRVInstructionSelector::selectIntrinsic during --gpu-lower-to-xevm-pipeline

5 participants