[SPIRV] Fix getNumScalarOrVectorTotalBitWidth to handle OpTypeBool type#186296
[SPIRV] Fix getNumScalarOrVectorTotalBitWidth to handle OpTypeBool type#186296
Conversation
|
@llvm/pr-subscribers-backend-spir-v Author: woruyu (woruyu) ChangesSummaryThis PR resolves #185815. Problem
Full diff: https://github.com/llvm/llvm-project/pull/186296.diff 2 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9a85634c82626..5eb8928761f68 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1428,8 +1428,9 @@ unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
}
return Type->getOpcode() == SPIRV::OpTypeInt ||
- Type->getOpcode() == SPIRV::OpTypeFloat
- ? NumElements * Type->getOperand(1).getImm()
+ Type->getOpcode() == SPIRV::OpTypeFloat ||
+ Type->getOpcode() == SPIRV::OpTypeBool
+ ? NumElements * getScalarOrVectorBitWidth(Type)
: 0;
}
diff --git a/llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll b/llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll
new file mode 100644
index 0000000000000..8ff60ba68af60
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/masked-store-bool-mask.ll
@@ -0,0 +1,91 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+define void @kernelAddConstant(<8 x i1> %0) {
+; CHECK-LABEL: kernelAddConstant
+; CHECK: %44 = OpFunction %20 None %21 ; -- Begin function kernelAddConstant
+; CHECK-NEXT: %45 = OpFunctionParameter %19
+; CHECK-NEXT: %78 = OpLabel
+; CHECK-NEXT: %46 = OpBitcast %24 %45
+; CHECK-NEXT: %47 = OpBitwiseAnd %24 %46 %43
+; CHECK-NEXT: %48 = OpINotEqual %18 %47 %42
+; CHECK-NEXT: %49 = OpBitcast %23 %41
+; CHECK-NEXT: %50 = OpInBoundsPtrAccessChain %23 %49 %40
+; CHECK-NEXT: %51 = OpBitcast %23 %41
+; CHECK-NEXT: %52 = OpInBoundsPtrAccessChain %23 %51 %39
+; CHECK-NEXT: %53 = OpBitcast %23 %41
+; CHECK-NEXT: %54 = OpInBoundsPtrAccessChain %23 %53 %38
+; CHECK-NEXT: %55 = OpBitcast %23 %41
+; CHECK-NEXT: %56 = OpInBoundsPtrAccessChain %23 %55 %37
+; CHECK-NEXT: %57 = OpBitcast %23 %41
+; CHECK-NEXT: %58 = OpInBoundsPtrAccessChain %23 %57 %36
+; CHECK-NEXT: %59 = OpBitcast %23 %41
+; CHECK-NEXT: %60 = OpInBoundsPtrAccessChain %23 %59 %35
+; CHECK-NEXT: %61 = OpBitcast %23 %41
+; CHECK-NEXT: %62 = OpInBoundsPtrAccessChain %23 %61 %34
+; CHECK-NEXT: OpBranchConditional %48 %2 %3
+; CHECK-NEXT: %2 = OpLabel
+; CHECK-NEXT: %63 = OpBitcast %23 %41
+; CHECK-NEXT: OpStore %63 %33 Aligned 1
+; CHECK-NEXT: OpBranch %3
+; CHECK-NEXT: %3 = OpLabel
+; CHECK-NEXT: %64 = OpBitwiseAnd %24 %46 %32
+; CHECK-NEXT: %65 = OpINotEqual %18 %64 %42
+; CHECK-NEXT: OpBranchConditional %65 %4 %5
+; CHECK-NEXT: %4 = OpLabel
+; CHECK-NEXT: OpStore %50 %33 Aligned 1
+; CHECK-NEXT: OpBranch %5
+; CHECK-NEXT: %5 = OpLabel
+; CHECK-NEXT: %66 = OpBitwiseAnd %24 %46 %31
+; CHECK-NEXT: %67 = OpINotEqual %18 %66 %42
+; CHECK-NEXT: OpBranchConditional %67 %6 %7
+; CHECK-NEXT: %6 = OpLabel
+; CHECK-NEXT: OpStore %52 %33 Aligned 1
+; CHECK-NEXT: OpBranch %7
+; CHECK-NEXT: %7 = OpLabel
+; CHECK-NEXT: %68 = OpBitwiseAnd %24 %46 %30
+; CHECK-NEXT: %69 = OpINotEqual %18 %68 %42
+; CHECK-NEXT: OpBranchConditional %69 %8 %9
+; CHECK-NEXT: %8 = OpLabel
+; CHECK-NEXT: OpStore %54 %33 Aligned 1
+; CHECK-NEXT: OpBranch %9
+; CHECK-NEXT: %9 = OpLabel
+; CHECK-NEXT: %70 = OpBitwiseAnd %24 %46 %29
+; CHECK-NEXT: %71 = OpINotEqual %18 %70 %42
+; CHECK-NEXT: OpBranchConditional %71 %10 %11
+; CHECK-NEXT: %10 = OpLabel
+; CHECK-NEXT: OpStore %56 %33 Aligned 1
+; CHECK-NEXT: OpBranch %11
+; CHECK-NEXT: %11 = OpLabel
+; CHECK-NEXT: %72 = OpBitwiseAnd %24 %46 %28
+; CHECK-NEXT: %73 = OpINotEqual %18 %72 %42
+; CHECK-NEXT: OpBranchConditional %73 %12 %13
+; CHECK-NEXT: %12 = OpLabel
+; CHECK-NEXT: OpStore %58 %33 Aligned 1
+; CHECK-NEXT: OpBranch %13
+; CHECK-NEXT: %13 = OpLabel
+; CHECK-NEXT: %74 = OpBitwiseAnd %24 %46 %27
+; CHECK-NEXT: %75 = OpINotEqual %18 %74 %42
+; CHECK-NEXT: OpBranchConditional %75 %14 %15
+; CHECK-NEXT: %14 = OpLabel
+; CHECK-NEXT: OpStore %60 %33 Aligned 1
+; CHECK-NEXT: OpBranch %15
+; CHECK-NEXT: %15 = OpLabel
+; CHECK-NEXT: %76 = OpBitwiseAnd %24 %46 %26
+; CHECK-NEXT: %77 = OpINotEqual %18 %76 %42
+; CHECK-NEXT: OpBranchConditional %77 %16 %17
+; CHECK-NEXT: %16 = OpLabel
+; CHECK-NEXT: OpStore %62 %33 Aligned 1
+; CHECK-NEXT: OpBranch %17
+; CHECK-NEXT: %17 = OpLabel
+; CHECK-NEXT: OpReturn
+; CHECK-NEXT: OpFunctionEnd
+ call void @llvm.masked.store.v8i32.p1(<8 x i32> zeroinitializer, ptr addrspace(1) null, <8 x i1> %0)
+ ret void
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: write)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) #0
+
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: write) }
|
jmmartinez
left a comment
There was a problem hiding this comment.
The patch looks reasonable to me. I left some rather minor remarks.
However It seems the CI tests failed:
Failed Tests (2):
LLVM :: CodeGen/SPIRV/masked-store-bool-mask.ll
LLVM :: CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll
I'd like you to check where the spirv-val fails are coming from before we land this.
| ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: write) | ||
| declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) #0 | ||
|
|
||
| attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: write) } |
There was a problem hiding this comment.
No need to put the attributes, the compiler can deduce them. Also, the ; Function Attrs is a comment so we can safely remove it.
| ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: write) | |
| declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) #0 | |
| attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: write) } | |
| declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1) captures(none), <8 x i1>) |
| @@ -0,0 +1,91 @@ | |||
| ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 | |||
There was a problem hiding this comment.
Same comment as in https://github.com/llvm/llvm-project/pull/186028/changes#r2923686391
| @@ -0,0 +1,91 @@ | |||
| ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 | |||
| ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s | |||
| ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} | |||
There was a problem hiding this comment.
When I ran this through spirv-val I've got:
error: line 56: Expected input to be a pointer or int or float vector or scalar: Bitcast
%46 = OpBitcast %uchar %45
Have you looked into it?
There was a problem hiding this comment.
I found that the generated SPIR-V contains:
%18 = OpTypeBool
%19 = OpTypeVector %18 8
%24 = OpTypeInt 8 0
%45 = OpFunctionParameter %19
%46 = OpBitcast %24 %45
and spirv-val rejects it because OpBitcast does not accept OpTypeBool or vector<bool> operands.
| ; CHECK-NEXT: %17 = OpLabel | ||
| ; CHECK-NEXT: OpReturn | ||
| ; CHECK-NEXT: OpFunctionEnd | ||
| call void @llvm.masked.store.v8i32.p1(<8 x i32> zeroinitializer, ptr addrspace(1) null, <8 x i1> %0) |
There was a problem hiding this comment.
Could you use a vector <8 x i32> <i32 1, i32 2, i32 3, ...> instead of zeroinitializer ?
When I was reading the generated code I saw the conditional store 0 to the right address. It'd help to distinguish each of the generated stores for each elements of the vector. Otherwise they all look the same.
|
| Type->getOpcode() == SPIRV::OpTypeFloat | ||
| ? NumElements * Type->getOperand(1).getImm() | ||
| Type->getOpcode() == SPIRV::OpTypeFloat || | ||
| Type->getOpcode() == SPIRV::OpTypeBool |
There was a problem hiding this comment.
OpTypeBool in SPIR-V is not the same as i1 from LLVM IR (so frankly, the fact that we are mapping them anyway is a bit speculative, but just works in most of the use cases). But since it has the following line in the definition:
There is no physical size or bit pattern defined for these values.
I feel uneasy to add it to getNumScalarOrVectorTotalBitWidth.
There was a problem hiding this comment.
Thanks for the review.
I’m a little confused about the best way to address this. getNumScalarOrVectorTotalBitWidth() is only used by isBitcastCompatible(), where we need some notion of total bit width in order to check whether Bits1 == Bits2.
Since OpTypeBool does not have a defined physical size in SPIR-V, would you prefer that I handle it directly in isBitcastCompatible() instead of extending this helper?
There was a problem hiding this comment.
I'm actually investigating a similar issue at the moment. The solution I came up with is to pre-process bool vector bitcasts by decomposing such vectors:
bitcast to iN => extract each bool element, zext to integer, shift left, OR together
bitcast iN to => AND with each bit mask, icmp ne 0, insert into result vector
There was a problem hiding this comment.
Other option is to resolve masked.load/store before scalarization, avoiding invalid bitcast insertion all along. While I have a patch for vector decomposition - I'm still bargaining with myself whether one or another approach should be productized :D
There was a problem hiding this comment.
I've convinced myself, that element-wise decomposition is fine, created #187960
|
I have a question about the expected handling here. My current understanding is that this should probably not be fixed by relaxing selectIntrinsic. Instead, <-> iN likely needs to be lowered as pack/unpack before SPIR-V bitcast emission. Would the right fix be to stop creating llvm.spv.bitcast for bool-vector cases in SPIRVEmitIntrinsics, or should this be expanded in legalizer instead? |
s-perron
left a comment
There was a problem hiding this comment.
I agree. Bitcast on bool is not valid spir-v. We should not be trying to generate that code. Returning 0 for bool seems appropriate.
Summary
This PR resolves #185815.
Problem
getNumScalarOrVectorTotalBitWidthpreviously only handledOpTypeIntandOpTypeFloat. For boolean scalar/vector types it returned0, which could break SPIR-V codegen for IR using boolean masks, such asllvm.masked.store.