diff --git a/BUILD.gn b/BUILD.gn index c69132ddcd..abf9be617f 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -630,6 +630,8 @@ static_library("spvtools_opt") { "source/opt/tree_iterator.h", "source/opt/trim_capabilities_pass.cpp", "source/opt/trim_capabilities_pass.h", + "source/opt/trim_variable_pointers_capabilities_pass.cpp", + "source/opt/trim_variable_pointers_capabilities_pass.h", "source/opt/type_manager.cpp", "source/opt/type_manager.h", "source/opt/types.cpp", diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index fd4527bf50..5fd76317ce 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -32,6 +32,13 @@ class Pass; struct DescriptorSetAndBinding; } // namespace opt +enum class SSARewriteMode { + None, + All, + OpaqueOnly, + SpecialTypes, +}; + // C++ interface for SPIR-V optimization functionalities. It wraps the context // (including target environment and the corresponding SPIR-V grammar) and // provides methods for registering optimization passes and optimizing. @@ -102,6 +109,8 @@ class SPIRV_TOOLS_EXPORT Optimizer { // interface are considered live and are not eliminated. Optimizer& RegisterPerformancePasses(); Optimizer& RegisterPerformancePasses(bool preserve_interface); + Optimizer& RegisterPerformancePassesFastCompile(); + Optimizer& RegisterPerformancePassesFastCompile(bool preserve_interface); // Registers passes that attempt to improve the size of generated code. // This sequence of passes is subject to constant review and will change @@ -125,6 +134,10 @@ class SPIRV_TOOLS_EXPORT Optimizer { // interface are considered live and are not eliminated. Optimizer& RegisterLegalizationPasses(); Optimizer& RegisterLegalizationPasses(bool preserve_interface); + Optimizer& RegisterLegalizationPassesFastCompile(); + Optimizer& RegisterLegalizationPassesFastCompile( + bool preserve_interface, bool include_loop_unroll, + SSARewriteMode ssa_rewrite_mode); // Register passes specified in the list of |flags|. Each flag must be a // string of a form accepted by Optimizer::FlagHasValidForm(). @@ -710,6 +723,7 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0); // Only variables that are local to the function and of supported types are // processed (see IsSSATargetVar for details). Optimizer::PassToken CreateSSARewritePass(); +Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode); // Create pass to convert relaxed precision instructions to half precision. // This pass converts as many relaxed float32 arithmetic operations to half as @@ -949,6 +963,12 @@ Optimizer::PassToken CreateFixFuncCallArgumentsPass(); // the unknown capability interacts with one of the trimmed capabilities. Optimizer::PassToken CreateTrimCapabilitiesPass(); +// Creates a pass that trims unused VariablePointers capabilities. +// This pass is intended for targeted call-sites that need to remove stale +// VariablePointers / VariablePointersStorageBuffer declarations left after +// optimization when the final module no longer requires them. +Optimizer::PassToken CreateTrimVariablePointersCapabilitiesPass(); + // Creates a struct-packing pass. // This pass re-assigns all offset layout decorators to tightly pack // the struct with OpName matching `structToPack` according to the given packing diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index a0ca5b84b5..caab0694cd 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -128,6 +128,7 @@ set(SPIRV_TOOLS_OPT_SOURCES switch_descriptorset_pass.h tree_iterator.h trim_capabilities_pass.h + trim_variable_pointers_capabilities_pass.h type_manager.h types.h unify_const_pass.h @@ -249,6 +250,7 @@ set(SPIRV_TOOLS_OPT_SOURCES struct_packing_pass.cpp switch_descriptorset_pass.cpp trim_capabilities_pass.cpp + trim_variable_pointers_capabilities_pass.cpp type_manager.cpp types.cpp unify_const_pass.cpp diff --git a/source/opt/dead_variable_elimination.cpp b/source/opt/dead_variable_elimination.cpp index e39132c22d..507ac6e6b5 100644 --- a/source/opt/dead_variable_elimination.cpp +++ b/source/opt/dead_variable_elimination.cpp @@ -14,6 +14,7 @@ #include "source/opt/dead_variable_elimination.h" +#include #include #include "source/opt/ir_context.h" @@ -77,9 +78,43 @@ Pass::Status DeadVariableElimination::Process() { DeleteVariable(result_id); } } + + ids_to_remove.clear(); + for (auto& function : *get_module()) { + if (function.IsDeclaration()) continue; + + auto& entry = *function.begin(); + for (auto inst = entry.begin(); inst != entry.end(); ++inst) { + if (inst->opcode() != spv::Op::OpVariable) break; + if (!IsFunctionLocalVariable(&*inst)) continue; + if (IsLiveVar(inst->result_id())) continue; + ids_to_remove.push_back(inst->result_id()); + } + } + + if (!ids_to_remove.empty()) { + modified = true; + for (auto result_id : ids_to_remove) { + DeleteLocalVariable(result_id); + } + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } +bool DeadVariableElimination::IsFunctionLocalVariable( + const Instruction* inst) const { + if (inst->opcode() != spv::Op::OpVariable) return false; + + const Instruction* type_inst = get_def_use_mgr()->GetDef(inst->type_id()); + if (type_inst == nullptr || type_inst->opcode() != spv::Op::OpTypePointer) { + return false; + } + + return spv::StorageClass(type_inst->GetSingleWordInOperand(0)) == + spv::StorageClass::Function; +} + void DeadVariableElimination::DeleteVariable(uint32_t result_id) { Instruction* inst = get_def_use_mgr()->GetDef(result_id); assert(inst->opcode() == spv::Op::OpVariable && @@ -108,5 +143,19 @@ void DeadVariableElimination::DeleteVariable(uint32_t result_id) { } context()->KillDef(result_id); } + +void DeadVariableElimination::DeleteLocalVariable(uint32_t result_id) { + std::queue dead_stores; + std::unordered_set processed; + AddStores(result_id, &dead_stores); + while (!dead_stores.empty()) { + Instruction* inst = dead_stores.front(); + dead_stores.pop(); + if (!processed.insert(inst).second) continue; + DCEInst(inst, nullptr); + } + + context()->KillDef(result_id); +} } // namespace opt } // namespace spvtools diff --git a/source/opt/dead_variable_elimination.h b/source/opt/dead_variable_elimination.h index 5dde71ba79..2b4a71d986 100644 --- a/source/opt/dead_variable_elimination.h +++ b/source/opt/dead_variable_elimination.h @@ -37,6 +37,8 @@ class DeadVariableElimination : public MemPass { private: // Deletes the OpVariable instruction who result id is |result_id|. void DeleteVariable(uint32_t result_id); + void DeleteLocalVariable(uint32_t result_id); + bool IsFunctionLocalVariable(const Instruction* inst) const; // Keeps track of the number of references of an id. Once that value is 0, it // is safe to remove the corresponding instruction. diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index df35401ebc..7aa8ea43b7 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -305,6 +305,18 @@ bool LocalSingleStoreElimPass::RewriteLoads( else stored_id = store_inst->GetSingleWordInOperand(kVariableInitIdInIdx); + const auto get_image_pointer_id = [this](uint32_t value_id) { + Instruction* value_inst = context()->get_def_use_mgr()->GetDef(value_id); + while (value_inst && value_inst->opcode() == spv::Op::OpCopyObject) { + value_id = value_inst->GetSingleWordInOperand(0); + value_inst = context()->get_def_use_mgr()->GetDef(value_id); + } + if (!value_inst || value_inst->opcode() != spv::Op::OpLoad) { + return uint32_t{0}; + } + return value_inst->GetSingleWordInOperand(0); + }; + *all_rewritten = true; bool modified = false; for (Instruction* use : uses) { @@ -319,6 +331,17 @@ bool LocalSingleStoreElimPass::RewriteLoads( context()->KillNamesAndDecorates(use->result_id()); context()->ReplaceAllUsesWith(use->result_id(), stored_id); context()->KillInst(use); + } else if (use->opcode() == spv::Op::OpImageTexelPointer && + dominator_analysis->Dominates(store_inst, use)) { + const uint32_t image_ptr_id = get_image_pointer_id(stored_id); + if (image_ptr_id == 0) { + *all_rewritten = false; + continue; + } + modified = true; + context()->ForgetUses(use); + use->SetInOperand(0, {image_ptr_id}); + context()->AnalyzeUses(use); } else { *all_rewritten = false; } diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 4d061ff0c2..4fa06e9dbc 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -53,7 +53,27 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const { } bool MemPass::IsTargetType(const Instruction* typeInst) const { - if (IsBaseTargetType(typeInst)) return true; + switch (ssa_rewrite_mode_) { + case SSARewriteMode::None: + return false; + case SSARewriteMode::OpaqueOnly: + if (typeInst->IsOpaqueType()) return true; + break; + case SSARewriteMode::SpecialTypes: + switch (typeInst->opcode()) { + case spv::Op::OpTypePointer: + case spv::Op::OpTypeUntypedPointerKHR: + case spv::Op::OpTypeCooperativeMatrixNV: + case spv::Op::OpTypeCooperativeMatrixKHR: + return true; + default: + break; + } + break; + case SSARewriteMode::All: + if (IsBaseTargetType(typeInst)) return true; + break; + } if (typeInst->opcode() == spv::Op::OpTypeArray) { if (!IsTargetType( get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) { @@ -198,7 +218,7 @@ bool MemPass::IsLiveVar(uint32_t varId) const { void MemPass::AddStores(uint32_t ptr_id, std::queue* insts) { get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) { spv::Op op = user->opcode(); - if (IsNonPtrAccessChain(op)) { + if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) { AddStores(user->result_id(), insts); } else if (op == spv::Op::OpStore) { insts->push(user); @@ -243,6 +263,9 @@ void MemPass::DCEInst(Instruction* inst, MemPass::MemPass() {} +MemPass::MemPass(SSARewriteMode ssa_rewrite_mode) + : ssa_rewrite_mode_(ssa_rewrite_mode) {} + bool MemPass::HasOnlySupportedRefs(uint32_t varId) { return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { auto dbg_op = user->GetCommonDebugOpcode(); diff --git a/source/opt/mem_pass.h b/source/opt/mem_pass.h index 496286b5f8..d8003f0b60 100644 --- a/source/opt/mem_pass.h +++ b/source/opt/mem_pass.h @@ -25,6 +25,7 @@ #include #include +#include "spirv-tools/optimizer.hpp" #include "source/opt/basic_block.h" #include "source/opt/def_use_manager.h" #include "source/opt/dominator_analysis.h" @@ -69,6 +70,7 @@ class MemPass : public Pass { protected: MemPass(); + explicit MemPass(SSARewriteMode ssa_rewrite_mode); // Returns true if |typeInst| is a scalar type // or a vector or matrix @@ -133,7 +135,9 @@ class MemPass : public Pass { // Cache of verified non-target vars std::unordered_set seen_non_target_vars_; - private: +private: + SSARewriteMode ssa_rewrite_mode_ = SSARewriteMode::All; + // Return true if all uses of |varId| are only through supported reference // operations ie. loads and store. Also cache in supported_ref_vars_. // TODO(dnovillo): This function is replicated in other passes and it's diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 6986501dba..cf95cdfed1 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -28,6 +28,7 @@ #include "source/opt/log.h" #include "source/opt/pass_manager.h" #include "source/opt/passes.h" +#include "source/opt/trim_variable_pointers_capabilities_pass.h" #include "source/spirv_optimizer_options.h" #include "source/util/make_unique.h" #include "source/util/string_utils.h" @@ -180,6 +181,72 @@ Optimizer& Optimizer::RegisterLegalizationPasses() { return RegisterLegalizationPasses(false); } +Optimizer& Optimizer::RegisterLegalizationPassesFastCompile( + bool preserve_interface, bool include_loop_unroll, + SSARewriteMode ssa_rewrite_mode) { + auto& optimizer = + // Wrap OpKill instructions so all other code can be inlined. + RegisterPass(CreateWrapOpKillPass()) + // Remove unreachable block so that merge return works. + .RegisterPass(CreateDeadBranchElimPass()) + // Merge the returns so we can inline. + .RegisterPass(CreateMergeReturnPass()) + // Make sure uses and definitions are in the same function. + .RegisterPass(CreateInlineExhaustivePass()) + .RegisterPass(CreateEliminateDeadFunctionsPass()); + optimizer.RegisterPass(CreatePrivateToLocalPass()); + // Fix up the storage classes that DXC may have purposely generated + // incorrectly. All functions are inlined, and a lot of dead code has + // been removed. + optimizer.RegisterPass(CreateFixStorageClassPass()); + // Propagate the value stored to the loads in very simple cases. + optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateSSARewritePass(SSARewriteMode::SpecialTypes)); + optimizer + // Split up aggregates so they are easier to deal with. + .RegisterPass(CreateScalarReplacementPass(0)); + // Remove loads and stores so everything is in intermediate values. + // Takes care of copy propagation of non-members. + optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)); + if (ssa_rewrite_mode != SSARewriteMode::None) { + optimizer.RegisterPass(CreateSSARewritePass(ssa_rewrite_mode)); + } + optimizer + // Propagate constants to get as many constant conditions on branches + // as possible. + .RegisterPass(CreateCCPPass()); + if (include_loop_unroll) { + optimizer.RegisterPass(CreateLoopUnrollPass(true)); + } + optimizer.RegisterPass(CreateDeadBranchElimPass()) + // Copy propagate members. Cleans up code sequences generated by scalar + // replacement. Also important for removing OpPhi nodes. + .RegisterPass(CreateSimplificationPass()); + return optimizer + // May need loop unrolling here see + // https://github.com/Microsoft/DirectXShaderCompiler/pull/930 + // Get rid of unused code that contain traces of illegal code + // or unused references to unbound external objects + .RegisterPass(CreateVectorDCEPass()) + .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateReduceLoadSizePass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateDeadVariableEliminationPass()) + .RegisterPass(CreateRemoveUnusedInterfaceVariablesPass()) + .RegisterPass(CreateInterpolateFixupPass()) + .RegisterPass(CreateInvocationInterlockPlacementPass()) + .RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass()); +} + +Optimizer& Optimizer::RegisterLegalizationPassesFastCompile() { + return RegisterLegalizationPassesFastCompile(false, true, + SSARewriteMode::All); +} + Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) { return RegisterPass(CreateWrapOpKillPass()) .RegisterPass(CreateDeadBranchElimPass()) @@ -231,6 +298,57 @@ Optimizer& Optimizer::RegisterPerformancePasses() { return RegisterPerformancePasses(false); } +Optimizer& Optimizer::RegisterPerformancePassesFastCompile( + bool preserve_interface) { + auto& optimizer = RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateDeadVariableEliminationPass()) + .RegisterPass(CreateRemoveUnusedInterfaceVariablesPass()) + .RegisterPass(CreateWrapOpKillPass()) + .RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateMergeReturnPass()) + .RegisterPass(CreateInlineExhaustivePass()) + .RegisterPass(CreateEliminateDeadFunctionsPass()) + .RegisterPass(CreatePrivateToLocalPass()) + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateScalarReplacementPass(0)) + .RegisterPass(CreateLocalAccessChainConvertPass()); + optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)); + optimizer.RegisterPass(CreateCCPPass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)); + optimizer.RegisterPass(CreateDeadBranchElimPass()); + optimizer.RegisterPass(CreateLocalRedundancyEliminationPass()); + optimizer.RegisterPass(CreateCombineAccessChainsPass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateScalarReplacementPass(0)) + .RegisterPass(CreateLocalAccessChainConvertPass()) + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateVectorDCEPass()) + .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateIfConversionPass()) + .RegisterPass(CreateCopyPropagateArraysPass()) + .RegisterPass(CreateReduceLoadSizePass()) + .RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateBlockMergePass()); + optimizer.RegisterPass(CreateLocalRedundancyEliminationPass()); + return optimizer.RegisterPass(CreateAggressiveDCEPass(preserve_interface)) + .RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateBlockMergePass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateTrimVariablePointersCapabilitiesPass()); +} + +Optimizer& Optimizer::RegisterPerformancePassesFastCompile() { + return RegisterPerformancePassesFastCompile(false); +} + Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) { return RegisterPass(CreateWrapOpKillPass()) .RegisterPass(CreateDeadBranchElimPass()) @@ -1024,6 +1142,11 @@ Optimizer::PassToken CreateSSARewritePass() { MakeUnique()); } +Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode) { + return MakeUnique( + MakeUnique(mode)); +} + Optimizer::PassToken CreateCopyPropagateArraysPass() { return MakeUnique( MakeUnique()); @@ -1175,6 +1298,11 @@ Optimizer::PassToken CreateTrimCapabilitiesPass() { MakeUnique()); } +Optimizer::PassToken CreateTrimVariablePointersCapabilitiesPass() { + return MakeUnique( + MakeUnique()); +} + Optimizer::PassToken CreateStructPackingPass(const char* structToPack, const char* packingRule) { return MakeUnique( diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index cdda3804b3..3495179a80 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp @@ -467,6 +467,21 @@ void ScalarReplacementPass::TransferAnnotations( } } +bool ScalarReplacementPass::IsPhysicalStorageBufferPointerVariable( + const Instruction* var_inst) const { + if (var_inst->opcode() != spv::Op::OpVariable) { + return false; + } + + Instruction* storage_type = GetStorageType(var_inst); + if (storage_type->opcode() != spv::Op::OpTypePointer) { + return false; + } + + return spv::StorageClass(storage_type->GetSingleWordInOperand(0)) == + spv::StorageClass::PhysicalStorageBuffer; +} + void ScalarReplacementPass::CreateVariable( uint32_t type_id, Instruction* var_inst, uint32_t index, std::vector* replacements) { @@ -987,6 +1002,7 @@ void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from, void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable( Instruction* from, Instruction* to, uint32_t member_index) { Instruction* type_inst = GetStorageType(from); + std::vector decorations_to_kill; for (auto dec_inst : get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) { uint32_t decoration; @@ -1002,7 +1018,17 @@ void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable( case spv::Decoration::AlignmentId: case spv::Decoration::MaxByteOffset: case spv::Decoration::MaxByteOffsetId: - case spv::Decoration::RelaxedPrecision: { + case spv::Decoration::RelaxedPrecision: + case spv::Decoration::AliasedPointer: + case spv::Decoration::RestrictPointer: { + if ((decoration == uint32_t(spv::Decoration::AliasedPointer) || + decoration == uint32_t(spv::Decoration::RestrictPointer)) && + (!IsPhysicalStorageBufferPointerVariable(to) || + get_decoration_mgr()->HasDecoration( + to->result_id(), static_cast(decoration)))) { + decorations_to_kill.push_back(dec_inst); + break; + } std::unique_ptr new_dec_inst( new Instruction(context(), spv::Op::OpDecorate, 0, 0, {})); new_dec_inst->AddOperand( @@ -1011,12 +1037,19 @@ void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable( new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i))); } context()->AddAnnotationInst(std::move(new_dec_inst)); + if (decoration == uint32_t(spv::Decoration::AliasedPointer) || + decoration == uint32_t(spv::Decoration::RestrictPointer)) { + decorations_to_kill.push_back(dec_inst); + } } break; default: break; } } } + for (auto* decoration_inst : decorations_to_kill) { + context()->KillInst(decoration_inst); + } } } // namespace opt diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h index 77d5bd5f1a..8745b9a120 100644 --- a/source/opt/scalar_replacement_pass.h +++ b/source/opt/scalar_replacement_pass.h @@ -123,6 +123,10 @@ class ScalarReplacementPass : public MemPass { void TransferAnnotations(const Instruction* source, std::vector* replacements); + // Returns true if |var_inst| stores a PhysicalStorageBuffer pointer. + bool IsPhysicalStorageBufferPointerVariable( + const Instruction* var_inst) const; + // Scalarizes |inst| and updates its uses. // // |inst| must be an OpVariable. It is replaced with an OpVariable for each diff --git a/source/opt/ssa_rewrite_pass.h b/source/opt/ssa_rewrite_pass.h index 076d9e1651..47aab7f27f 100644 --- a/source/opt/ssa_rewrite_pass.h +++ b/source/opt/ssa_rewrite_pass.h @@ -295,6 +295,7 @@ class SSARewriter { class SSARewritePass : public MemPass { public: SSARewritePass() = default; + explicit SSARewritePass(SSARewriteMode mode) : MemPass(mode) {} const char* name() const override { return "ssa-rewrite"; } Status Process() override; diff --git a/source/opt/trim_variable_pointers_capabilities_pass.cpp b/source/opt/trim_variable_pointers_capabilities_pass.cpp new file mode 100644 index 0000000000..5db11eaa78 --- /dev/null +++ b/source/opt/trim_variable_pointers_capabilities_pass.cpp @@ -0,0 +1,343 @@ +// Copyright (c) 2026 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/trim_variable_pointers_capabilities_pass.h" + +#include +#include +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +namespace { + +constexpr uint32_t kOpTypePointerStorageClassIndex = 0; +constexpr uint32_t kTypePointerTypeIdInIndex = 1; + +struct RequiredVariablePointerCapabilities { + bool variable_pointers = false; + bool variable_pointers_storage_buffer = false; + + void Add(spv::Capability capability) { + switch (capability) { + case spv::Capability::VariablePointers: + variable_pointers = true; + break; + case spv::Capability::VariablePointersStorageBuffer: + variable_pointers_storage_buffer = true; + break; + default: + break; + } + } +}; + +template +void DFSWhile(const Instruction* instruction, UnaryPredicate condition) { + std::stack instructions_to_visit; + std::unordered_set visited_instructions; + instructions_to_visit.push(instruction->result_id()); + const auto* def_use_mgr = instruction->context()->get_def_use_mgr(); + + while (!instructions_to_visit.empty()) { + const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top()); + instructions_to_visit.pop(); + + if (item == nullptr) { + continue; + } + + if (visited_instructions.count(item->result_id()) != 0) { + continue; + } + visited_instructions.insert(item->result_id()); + + if (!condition(item)) { + continue; + } + + if (item->opcode() == spv::Op::OpTypePointer) { + instructions_to_visit.push( + item->GetSingleWordInOperand(kTypePointerTypeIdInIndex)); + continue; + } + + if (item->opcode() == spv::Op::OpTypeMatrix || + item->opcode() == spv::Op::OpTypeVector || + item->opcode() == spv::Op::OpTypeArray || + item->opcode() == spv::Op::OpTypeRuntimeArray) { + instructions_to_visit.push(item->GetSingleWordInOperand(0)); + continue; + } + + if (item->opcode() == spv::Op::OpTypeStruct) { + item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) { + instructions_to_visit.push(*op_id); + }); + } + } +} + +template +bool AnyTypeOf(const Instruction* instruction, UnaryPredicate predicate) { + if (instruction == nullptr || !IsTypeInst(instruction->opcode())) { + return false; + } + + bool found_one = false; + DFSWhile(instruction, [&found_one, predicate](const Instruction* node) { + if (found_one || predicate(node)) { + found_one = true; + return false; + } + + return true; + }); + return found_one; +} + +std::optional GetVariablePointerCapability( + spv::StorageClass storage_class) { + switch (storage_class) { + case spv::StorageClass::StorageBuffer: + return spv::Capability::VariablePointersStorageBuffer; + case spv::StorageClass::Workgroup: + return spv::Capability::VariablePointers; + default: + return std::nullopt; + } +} + +std::optional GetLogicalPointerStorageClass( + const Instruction* type_instruction) { + if (type_instruction == nullptr) { + return std::nullopt; + } + + if (type_instruction->opcode() != spv::Op::OpTypePointer && + type_instruction->opcode() != spv::Op::OpTypeUntypedPointerKHR) { + return std::nullopt; + } + + const auto storage_class = static_cast( + type_instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex)); + if (storage_class == spv::StorageClass::PhysicalStorageBuffer) { + return std::nullopt; + } + + return storage_class; +} + +std::optional GetLogicalPointerResultStorageClass( + const Instruction* instruction) { + if (instruction == nullptr || instruction->type_id() == 0) { + return std::nullopt; + } + + return GetLogicalPointerStorageClass( + instruction->context()->get_def_use_mgr()->GetDef(instruction->type_id())); +} + +void AddCapabilityForStorageClass( + std::optional storage_class, + RequiredVariablePointerCapabilities* required_capabilities) { + if (!storage_class) { + return; + } + + if (const auto capability = GetVariablePointerCapability(*storage_class)) { + required_capabilities->Add(*capability); + } +} + +void AddVariablePointerCapabilityForResult( + const Instruction* instruction, + RequiredVariablePointerCapabilities* required_capabilities) { + AddCapabilityForStorageClass( + GetLogicalPointerResultStorageClass(instruction), required_capabilities); +} + +void AddVariablePointerCapabilitiesFromPointerOperands( + const Instruction* instruction, + RequiredVariablePointerCapabilities* required_capabilities) { + instruction->ForEachInId([instruction, required_capabilities](const uint32_t* id) { + const auto* operand_instruction = + instruction->context()->get_def_use_mgr()->GetDef(*id); + AddCapabilityForStorageClass( + GetLogicalPointerResultStorageClass(operand_instruction), + required_capabilities); + }); +} + +void AddVariablePointerCapabilitiesFromAllocatedType( + const Instruction* instruction, + RequiredVariablePointerCapabilities* required_capabilities) { + if (instruction->type_id() == 0) { + return; + } + + const auto* variable_type = + instruction->context()->get_def_use_mgr()->GetDef(instruction->type_id()); + if (variable_type == nullptr || + (variable_type->opcode() != spv::Op::OpTypePointer && + variable_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) { + return; + } + + if (variable_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) { + AddCapabilityForStorageClass(GetLogicalPointerStorageClass(variable_type), + required_capabilities); + return; + } + + const auto* pointee_type = + instruction->context()->get_def_use_mgr()->GetDef( + variable_type->GetSingleWordInOperand(kTypePointerTypeIdInIndex)); + if (pointee_type == nullptr) { + return; + } + + if (AnyTypeOf(pointee_type, [](const Instruction* type_instruction) { + return GetLogicalPointerStorageClass(type_instruction) == + spv::StorageClass::StorageBuffer; + })) { + required_capabilities->Add( + spv::Capability::VariablePointersStorageBuffer); + } + + if (AnyTypeOf(pointee_type, [](const Instruction* type_instruction) { + return GetLogicalPointerStorageClass(type_instruction) == + spv::StorageClass::Workgroup; + })) { + required_capabilities->Add(spv::Capability::VariablePointers); + } +} + +void AddVariablePointerCapabilityRequirements( + const Instruction* instruction, + RequiredVariablePointerCapabilities* required_capabilities) { + switch (instruction->opcode()) { + case spv::Op::OpSelect: + case spv::Op::OpPhi: + case spv::Op::OpFunctionCall: + case spv::Op::OpPtrAccessChain: + case spv::Op::OpLoad: + case spv::Op::OpConstantNull: + case spv::Op::OpFunction: + case spv::Op::OpFunctionParameter: + case spv::Op::OpUntypedPtrAccessChainKHR: + case spv::Op::OpUntypedInBoundsPtrAccessChainKHR: + AddVariablePointerCapabilityForResult(instruction, required_capabilities); + break; + default: + break; + } + + switch (instruction->opcode()) { + case spv::Op::OpReturnValue: + case spv::Op::OpStore: + case spv::Op::OpPtrAccessChain: + case spv::Op::OpPtrEqual: + case spv::Op::OpPtrNotEqual: + case spv::Op::OpPtrDiff: + case spv::Op::OpSelect: + case spv::Op::OpPhi: + case spv::Op::OpVariable: + case spv::Op::OpFunctionCall: + case spv::Op::OpUntypedPtrAccessChainKHR: + case spv::Op::OpUntypedInBoundsPtrAccessChainKHR: + AddVariablePointerCapabilitiesFromPointerOperands( + instruction, required_capabilities); + break; + default: + break; + } + + switch (instruction->opcode()) { + case spv::Op::OpVariable: + case spv::Op::OpUntypedVariableKHR: + AddVariablePointerCapabilitiesFromAllocatedType( + instruction, required_capabilities); + break; + default: + break; + } +} + +bool CanRemoveVariablePointers( + const RequiredVariablePointerCapabilities& required_capabilities, + bool has_explicit_variable_pointers, + bool has_explicit_variable_pointers_storage_buffer) { + return has_explicit_variable_pointers && + !required_capabilities.variable_pointers && + (!required_capabilities.variable_pointers_storage_buffer || + has_explicit_variable_pointers_storage_buffer); +} + +bool CanRemoveVariablePointersStorageBuffer( + const RequiredVariablePointerCapabilities& required_capabilities, + bool has_explicit_variable_pointers, + bool has_explicit_variable_pointers_storage_buffer) { + return has_explicit_variable_pointers_storage_buffer && + !required_capabilities.variable_pointers_storage_buffer && + (!required_capabilities.variable_pointers || + has_explicit_variable_pointers); +} + +} // namespace + +Pass::Status TrimVariablePointersCapabilitiesPass::Process() { + const bool has_explicit_variable_pointers = get_module()->HasExplicitCapability( + static_cast(spv::Capability::VariablePointers)); + const bool has_explicit_variable_pointers_storage_buffer = + get_module()->HasExplicitCapability( + static_cast( + spv::Capability::VariablePointersStorageBuffer)); + + if (!has_explicit_variable_pointers && + !has_explicit_variable_pointers_storage_buffer) { + return Status::SuccessWithoutChange; + } + + RequiredVariablePointerCapabilities required_capabilities; + get_module()->ForEachInst( + [&required_capabilities](Instruction* instruction) { + AddVariablePointerCapabilityRequirements(instruction, + &required_capabilities); + }, + true); + + bool modified = false; + if (CanRemoveVariablePointers(required_capabilities, + has_explicit_variable_pointers, + has_explicit_variable_pointers_storage_buffer)) { + context()->RemoveCapability(spv::Capability::VariablePointers); + modified = true; + } + + if (CanRemoveVariablePointersStorageBuffer( + required_capabilities, has_explicit_variable_pointers, + has_explicit_variable_pointers_storage_buffer)) { + context()->RemoveCapability(spv::Capability::VariablePointersStorageBuffer); + modified = true; + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/trim_variable_pointers_capabilities_pass.h b/source/opt/trim_variable_pointers_capabilities_pass.h new file mode 100644 index 0000000000..550fe7efc4 --- /dev/null +++ b/source/opt/trim_variable_pointers_capabilities_pass.h @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_TRIM_VARIABLE_POINTERS_CAPABILITIES_PASS_H_ +#define SOURCE_OPT_TRIM_VARIABLE_POINTERS_CAPABILITIES_PASS_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +class TrimVariablePointersCapabilitiesPass : public Pass { + public: + TrimVariablePointersCapabilitiesPass() = default; + TrimVariablePointersCapabilitiesPass( + const TrimVariablePointersCapabilitiesPass&) = delete; + TrimVariablePointersCapabilitiesPass( + TrimVariablePointersCapabilitiesPass&&) = delete; + + const char* name() const override { + return "trim-variable-pointers-capabilities"; + } + Status Process() override; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_TRIM_VARIABLE_POINTERS_CAPABILITIES_PASS_H_ diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index e2fdccf456..84d5bd29ec 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -110,6 +110,7 @@ add_spvtools_unittest(TARGET opt struct_packing_test.cpp switch_descriptorset_test.cpp trim_capabilities_pass_test.cpp + trim_variable_pointers_capabilities_pass_test.cpp type_manager_test.cpp types_test.cpp unify_const_test.cpp diff --git a/test/opt/dead_variable_elim_test.cpp b/test/opt/dead_variable_elim_test.cpp index a55ee62a55..c6de1c7177 100644 --- a/test/opt/dead_variable_elim_test.cpp +++ b/test/opt/dead_variable_elim_test.cpp @@ -293,6 +293,50 @@ OpFunctionEnd SinglePassRunAndCheck(before, after, true, true); } +TEST_F(DeadVariableElimTest, RemoveDeadFunctionLocalWithDeadStore) { + const std::string before = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpName %main "main" +OpName %dead "dead" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float +%buffer = OpVariable %_ptr_StorageBuffer_float StorageBuffer +%_ptr_Function__ptr_StorageBuffer_float = OpTypePointer Function %_ptr_StorageBuffer_float +%main = OpFunction %void None %3 +%entry = OpLabel +%dead = OpVariable %_ptr_Function__ptr_StorageBuffer_float Function +OpStore %dead %buffer +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpName %main "main" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float +%buffer = OpVariable %_ptr_StorageBuffer_float StorageBuffer +%_ptr_Function__ptr_StorageBuffer_float = OpTypePointer Function %_ptr_StorageBuffer_float +%main = OpFunction %void None %3 +%entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + } // namespace } // namespace opt } // namespace spvtools diff --git a/test/opt/local_single_store_elim_test.cpp b/test/opt/local_single_store_elim_test.cpp index 8fd5c9d2f1..8f4fd7f888 100644 --- a/test/opt/local_single_store_elim_test.cpp +++ b/test/opt/local_single_store_elim_test.cpp @@ -908,6 +908,82 @@ OpFunctionEnd SinglePassRunAndCheck(before, after, true, true); } +TEST_F(LocalSingleStoreElimTest, RewriteImageTexelPointerImageOperand) { + const std::string before = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" %g_rwTexture3d +OpExecutionMode %main LocalSize 256 1 1 +OpSource HLSL 660 +OpName %type_3d_image "type.3d.image" +OpName %g_rwTexture3d "g_rwTexture3d" +OpName %main "main" +OpDecorate %g_rwTexture3d DescriptorSet 0 +OpDecorate %g_rwTexture3d Binding 0 +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%v3uint = OpTypeVector %uint 3 +%10 = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3 +%type_3d_image = OpTypeImage %uint 3D 2 0 0 2 R32ui +%_ptr_UniformConstant_type_3d_image = OpTypePointer UniformConstant %type_3d_image +%void = OpTypeVoid +%13 = OpTypeFunction %void +%_ptr_Function_type_3d_image = OpTypePointer Function %type_3d_image +%_ptr_Image_uint = OpTypePointer Image %uint +%g_rwTexture3d = OpVariable %_ptr_UniformConstant_type_3d_image UniformConstant +%main = OpFunction %void None %13 +%16 = OpLabel +%17 = OpVariable %_ptr_Function_type_3d_image Function +%18 = OpLoad %type_3d_image %g_rwTexture3d +OpStore %17 %18 +%19 = OpImageTexelPointer %_ptr_Image_uint %17 %10 %uint_0 +%20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 +OpReturn +OpFunctionEnd +)"; + + const std::string after = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" %g_rwTexture3d +OpExecutionMode %main LocalSize 256 1 1 +OpSource HLSL 660 +OpName %type_3d_image "type.3d.image" +OpName %g_rwTexture3d "g_rwTexture3d" +OpName %main "main" +OpDecorate %g_rwTexture3d DescriptorSet 0 +OpDecorate %g_rwTexture3d Binding 0 +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%v3uint = OpTypeVector %uint 3 +%10 = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3 +%type_3d_image = OpTypeImage %uint 3D 2 0 0 2 R32ui +%_ptr_UniformConstant_type_3d_image = OpTypePointer UniformConstant %type_3d_image +%void = OpTypeVoid +%13 = OpTypeFunction %void +%_ptr_Function_type_3d_image = OpTypePointer Function %type_3d_image +%_ptr_Image_uint = OpTypePointer Image %uint +%g_rwTexture3d = OpVariable %_ptr_UniformConstant_type_3d_image UniformConstant +%main = OpFunction %void None %13 +%16 = OpLabel +%17 = OpVariable %_ptr_Function_type_3d_image Function +%18 = OpLoad %type_3d_image %g_rwTexture3d +OpStore %17 %18 +%19 = OpImageTexelPointer %_ptr_Image_uint %g_rwTexture3d %10 %uint_0 +%20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetTargetEnv(SPV_ENV_UNIVERSAL_1_4); + SinglePassRunAndCheck(before, after, true, true); +} + // Test that that an unused OpAccessChain between a store and a use does does // not hinders the replacement of the use. We need to check this because // local-access-chain-convert does always remove the OpAccessChain instructions diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp index 0256ac264f..96aee4ac88 100644 --- a/test/opt/scalar_replacement_test.cpp +++ b/test/opt/scalar_replacement_test.cpp @@ -2357,6 +2357,46 @@ TEST_F(ScalarReplacementTest, RestrictPointer) { SinglePassRunAndMatch(text, true); } +TEST_F(ScalarReplacementTest, RestrictPointerMember) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate [[struct_type:%\w+]] 0 RestrictPointer +; CHECK: OpDecorate [[new_var:%\w+]] RestrictPointer +; CHECK: [[struct_type]] = OpTypeStruct [[ptr_type:%\w+]] +; CHECK: [[ptr_type]] = OpTypePointer PhysicalStorageBuffer [[block_type:%\w+]] +; CHECK: [[var_type:%\w+]] = OpTypePointer Function [[ptr_type]] +; CHECK: [[new_var]] = OpVariable [[var_type]] Function + OpCapability Shader + OpCapability PhysicalStorageBufferAddresses + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel PhysicalStorageBuffer64 GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpMemberDecorate %3 0 Offset 0 + OpDecorate %3 Block + OpMemberDecorate %11 0 RestrictPointer + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpConstant %8 0 + %3 = OpTypeStruct %8 + %10 = OpTypePointer PhysicalStorageBuffer %3 + %11 = OpTypeStruct %10 + %13 = OpTypePointer Function %11 + %14 = OpTypePointer Function %10 + %16 = OpUndef %11 + %2 = OpFunction %6 None %7 + %17 = OpLabel + %5 = OpVariable %13 Function + OpStore %5 %16 + %18 = OpAccessChain %14 %5 %9 + OpReturn + OpFunctionEnd + )"; + + SetTargetEnv(SPV_ENV_UNIVERSAL_1_6); + SinglePassRunAndMatch(text, true); +} + } // namespace } // namespace opt } // namespace spvtools diff --git a/test/opt/trim_variable_pointers_capabilities_pass_test.cpp b/test/opt/trim_variable_pointers_capabilities_pass_test.cpp new file mode 100644 index 0000000000..f325d6385b --- /dev/null +++ b/test/opt/trim_variable_pointers_capabilities_pass_test.cpp @@ -0,0 +1,222 @@ +// Copyright (c) 2026 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "source/opt/trim_variable_pointers_capabilities_pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using TrimVariablePointersCapabilitiesPassTest = PassTest<::testing::Test>; + +TEST_F(TrimVariablePointersCapabilitiesPassTest, + VariablePointers_RemovedWhenUnused) { + const std::string kTest = R"( + OpCapability Shader + OpCapability VariablePointers +; CHECK: OpCapability Shader +; CHECK-NOT: OpCapability VariablePointers +; CHECK-NOT: OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %buf + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_struct_4 Block + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %buf DescriptorSet 0 + OpDecorate %buf Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_4 = OpTypeStruct %uint +%_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %buf = OpVariable %_ptr_StorageBuffer__struct_4 StorageBuffer + %main = OpFunction %void None %3 + %9 = OpLabel + %10 = OpAccessChain %_ptr_StorageBuffer_uint %buf %uint_0 %uint_0 + %11 = OpLoad %uint %10 + OpReturn + OpFunctionEnd + )"; + const auto result = + SinglePassRunAndMatch( + kTest, /* skip_nop= */ false); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange); +} + +TEST_F(TrimVariablePointersCapabilitiesPassTest, + VariablePointers_RemainsForWorkgroupSelect) { + const std::string kTest = R"( + OpCapability Shader + OpCapability VariablePointers +; CHECK: OpCapability Shader +; CHECK: OpCapability VariablePointers + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %var = OpVariable %_ptr_Workgroup_uint Workgroup + %main = OpFunction %void None %3 + %8 = OpLabel + %9 = OpSelect %_ptr_Workgroup_uint %true %var %var + OpReturn + OpFunctionEnd + )"; + const auto result = + SinglePassRunAndMatch( + kTest, /* skip_nop= */ false); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange); +} + +TEST_F(TrimVariablePointersCapabilitiesPassTest, + VariablePointers_RemainsWhenItIsTheOnlyExplicitCapabilityForStorageBuffer) { + const std::string kTest = R"( + OpCapability Shader + OpCapability VariablePointers +; CHECK: OpCapability Shader +; CHECK: OpCapability VariablePointers +; CHECK-NOT: OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %buf + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_struct_4 Block + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %buf DescriptorSet 0 + OpDecorate %buf Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_4 = OpTypeStruct %uint +%_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %8 = OpTypeFunction %uint %_ptr_StorageBuffer_uint + %buf = OpVariable %_ptr_StorageBuffer__struct_4 StorageBuffer + %main = OpFunction %void None %3 + %10 = OpLabel + %11 = OpAccessChain %_ptr_StorageBuffer_uint %buf %uint_0 %uint_0 + %12 = OpFunctionCall %uint %callee %11 + OpReturn + OpFunctionEnd + %callee = OpFunction %uint None %8 + %15 = OpFunctionParameter %_ptr_StorageBuffer_uint + %16 = OpLabel + OpReturnValue %uint_0 + OpFunctionEnd + )"; + const auto result = + SinglePassRunAndMatch( + kTest, /* skip_nop= */ false); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange); +} + +TEST_F(TrimVariablePointersCapabilitiesPassTest, + VariablePointersStorageBuffer_RemainsForFunctionCallParameter) { + const std::string kTest = R"( + OpCapability Shader + OpCapability VariablePointersStorageBuffer +; CHECK: OpCapability Shader +; CHECK: OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %buf + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_struct_4 Block + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %buf DescriptorSet 0 + OpDecorate %buf Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_4 = OpTypeStruct %uint +%_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %8 = OpTypeFunction %uint %_ptr_StorageBuffer_uint + %buf = OpVariable %_ptr_StorageBuffer__struct_4 StorageBuffer + %main = OpFunction %void None %3 + %10 = OpLabel + %11 = OpAccessChain %_ptr_StorageBuffer_uint %buf %uint_0 %uint_0 + %12 = OpFunctionCall %uint %callee %11 + OpReturn + OpFunctionEnd + %callee = OpFunction %uint None %8 + %15 = OpFunctionParameter %_ptr_StorageBuffer_uint + %16 = OpLabel + OpReturnValue %uint_0 + OpFunctionEnd + )"; + const auto result = + SinglePassRunAndMatch( + kTest, /* skip_nop= */ false); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange); +} + +TEST_F(TrimVariablePointersCapabilitiesPassTest, + VariablePointers_RemovedWhenStorageBufferCapabilityIsAlsoDeclared) { + const std::string kTest = R"( + OpCapability Shader + OpCapability VariablePointers + OpCapability VariablePointersStorageBuffer +; CHECK: OpCapability Shader +; CHECK-NOT: OpCapability VariablePointers +; CHECK: OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %buf + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_struct_4 Block + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %buf DescriptorSet 0 + OpDecorate %buf Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_4 = OpTypeStruct %uint +%_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %8 = OpTypeFunction %uint %_ptr_StorageBuffer_uint + %buf = OpVariable %_ptr_StorageBuffer__struct_4 StorageBuffer + %main = OpFunction %void None %3 + %10 = OpLabel + %11 = OpAccessChain %_ptr_StorageBuffer_uint %buf %uint_0 %uint_0 + %12 = OpFunctionCall %uint %callee %11 + OpReturn + OpFunctionEnd + %callee = OpFunction %uint None %8 + %15 = OpFunctionParameter %_ptr_StorageBuffer_uint + %16 = OpLabel + OpReturnValue %uint_0 + OpFunctionEnd + )"; + const auto result = + SinglePassRunAndMatch( + kTest, /* skip_nop= */ false); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange); +} + +} // namespace +} // namespace opt +} // namespace spvtools