From 9bff1d3ba879562575236ec480557a00348606d1 Mon Sep 17 00:00:00 2001 From: Helena Kotas Date: Fri, 3 Apr 2026 13:56:21 -0700 Subject: [PATCH] [SM6.10][Bugfix] Update vector sizes in linalg.h to match column-vector multiplication Also fixes an issue with `__builtin_LinAlg_MatrixVectorMultiply*` built-ins that did not allow vectors of different sizes. --- tools/clang/lib/Headers/hlsl/dx/linalg.h | 42 +++++++++---------- .../CodeGenDXIL/hlsl/linalg/api/vectors.hlsl | 38 +++++++++-------- utils/hct/gen_intrin_main.txt | 4 +- 3 files changed, 43 insertions(+), 41 deletions(-) diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h index 2a636662cd..d212d86017 100644 --- a/tools/clang/lib/Headers/hlsl/dx/linalg.h +++ b/tools/clang/lib/Headers/hlsl/dx/linalg.h @@ -465,7 +465,7 @@ Matrix Multiply( template // clang-format off -typename hlsl::enable_if::value, vector >::type +typename hlsl::enable_if::value, vector >::type // clang-format on Multiply(Matrix MatrixA, vector Vec) { @@ -479,11 +479,11 @@ Multiply(Matrix MatrixA, template // clang-format off -typename hlsl::enable_if::value, vector >::type +typename hlsl::enable_if::value, vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, - vector Vec, vector Bias) { - vector Result; + vector Vec, vector Bias) { + vector Result; __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, hlsl::is_signed::value, Vec, MatrixDT, Bias, MatrixDT); @@ -491,17 +491,17 @@ MultiplyAdd(Matrix MatrixA, } template // clang-format off typename hlsl::enable_if< - InterpretedVector::Size == M, - vector >::type + InterpretedVector::Size == K, + vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, - InterpretedVector InterpVec, - vector Bias) { - vector Result; + InterpretedVector InterpVec, + vector Bias) { + vector Result; __builtin_LinAlg_MatrixVectorMultiplyAdd( Result, MatrixA.__handle, hlsl::is_signed::value, InterpVec.Data, InterpVec.Interpretation, Bias, MatrixDT); @@ -512,14 +512,14 @@ template // clang-format off typename hlsl::enable_if::value, - vector >::type + vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, - vector Vec, VectorRef BiasRef) { + vector Vec, VectorRef BiasRef) { using BiasVecTy = - vector::Type, K>; + vector::Type, M>; BiasVecTy BiasVec = BiasRef.Buf.template Load(BiasRef.Offset); - vector Result; + vector Result; __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, hlsl::is_signed::value, Vec, MatrixDT, BiasVec, BiasElTy); @@ -527,20 +527,20 @@ MultiplyAdd(Matrix MatrixA, } template // clang-format off typename hlsl::enable_if< - InterpretedVector::Size == M, - vector >::type + InterpretedVector::Size == K, + vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, - InterpretedVector InterpVec, - VectorRef BiasRef) { + InterpretedVector InterpVec, + VectorRef BiasRef) { using BiasVecTy = - vector::Type, K>; + vector::Type, M>; BiasVecTy BiasVec = BiasRef.Buf.template Load(BiasRef.Offset); - vector Result; + vector Result; __builtin_LinAlg_MatrixVectorMultiplyAdd( Result, MatrixA.__handle, hlsl::is_signed::value, InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy); diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl index 98332c9ed9..876edcca46 100644 --- a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl @@ -4,7 +4,7 @@ #include using namespace dx::linalg; -using MatrixATy = Matrix; +using MatrixATy = Matrix; using MatrixAccumTy = Matrix; ByteAddressBuffer BAB : register(t0); @@ -12,29 +12,31 @@ ByteAddressBuffer BAB : register(t0); [numthreads(4, 4, 4)] void main(uint ID : SV_GroupID) { -// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N8U0S0( +// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0( // CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2) // CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) MatrixATy Mat1 = MatrixATy::Load(BAB, 0, 8); - vector vec1 = 10.3f; + vector vec1 = 10.3f; -// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N8U0S0.v8f16(i32 -2147483623, -// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %3, i1 true, <8 x half> , i32 8) -// CHECK-SAME: ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation) +// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> , i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation) vector vec2 = Multiply(Mat1, vec1); -// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622, -// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> , i32 8, <8 x half> %[[VEC2]], i32 8) +// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> , i32 8, <8 x half> %[[VEC2]], i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec3 = MultiplyAdd(Mat1, vec1, vec2); -// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622, -// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x half> %[[VEC3]], i32 8) +// CHECK: %[[VEC20:.*]] = shufflevector + vector vec20 = (vector)vec2; + +// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x half> %[[VEC3]], i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) - InterpretedVector interpVec2 = MakeInterpretedVector(vec2); + InterpretedVector interpVec2 = MakeInterpretedVector(vec20); vector vec4 = MultiplyAdd(Mat1, interpVec2, vec3); // CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303, @@ -42,11 +44,11 @@ void main(uint ID : SV_GroupID) { // CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0 - // CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622, - // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC3]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) + // CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) // CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) VectorRef memBias = {BAB, 4096}; - vector vec5 = MultiplyAdd(Mat1, vec3, memBias); + vector vec5 = MultiplyAdd(Mat1, interpVec2, memBias); // CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303, // CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2) @@ -54,8 +56,8 @@ void main(uint ID : SV_GroupID) { // CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0 - // CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622, - // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) + // CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec6 = MultiplyAdd(Mat1, interpVec2, memBias); diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index 239f381614..3075e8a4cf 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -404,8 +404,8 @@ uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp, in numeric<> bias, in uint biasInterp); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric input, in uint inputInterp); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric input, in uint inputInterp, in numeric bias, in uint biasInterp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);