diff --git a/include/matx/operators/hermitian.h b/include/matx/operators/hermitian.h
index 0d7b602dc..16e71b94f 100644
--- a/include/matx/operators/hermitian.h
+++ b/include/matx/operators/hermitian.h
@@ -112,6 +112,8 @@ namespace matx
 #endif
     __MATX_INLINE__ std::string str() const { return "hermitian(" + op_.str() + ")"; }
 
+    __MATX_INLINE__ const auto& Input() const noexcept { return op_; }
+
     __MATX_INLINE__ HermitianTransOp(const T1 &op) : op_(op) {
       MATX_LOG_TRACE("{} constructor: rank={}", str(), Rank());
       static_assert(Rank() >= 2, "Hermitian operation needs input with rank >= 2");
diff --git a/include/matx/operators/unary_operators.h b/include/matx/operators/unary_operators.h
index da5bb5bcc..edbbee10c 100644
--- a/include/matx/operators/unary_operators.h
+++ b/include/matx/operators/unary_operators.h
@@ -87,6 +87,10 @@ namespace matx
       return op_.str() + "(" + get_type_str(in1_) + ")";
     }
 
+    __MATX_INLINE__ const auto& Input() const noexcept {
+      return in1_;
+    }
+
     __MATX_INLINE__ matxUnaryOp(const I1 &in1, const Op &op) : in1_(in1), op_(op) {
       MATX_LOG_TRACE("{} constructor: rank={}", str(), Rank());
       if constexpr (Rank() > 0) {
diff --git a/include/matx/transforms/matmul/matmul_cuda.h b/include/matx/transforms/matmul/matmul_cuda.h
index a9e2b1390..6b5cb4385 100644
--- a/include/matx/transforms/matmul/matmul_cuda.h
+++ b/include/matx/transforms/matmul/matmul_cuda.h
@@ -47,6 +47,9 @@
 #include "matx/core/error.h"
 #include "matx/core/nvtx.h"
 #include "matx/core/tensor.h"
+#include "matx/operators/hermitian.h"
+#include "matx/operators/transpose.h"
+#include "matx/operators/unary_operators.h"
 #include "matx/transforms/matmul/matmul_common.h"
 
 namespace matx {
@@ -140,11 +143,71 @@ struct MatMulCUDAParams_t {
   MatXDataType_t dtype;
   cublasOperation_t opA;
   cublasOperation_t opB;
+  MemOrder_t orderA = MEM_ORDER_ROW_MAJOR;
+  MemOrder_t orderB = MEM_ORDER_ROW_MAJOR;
+  MemOrder_t orderC = MEM_ORDER_ROW_MAJOR;
   bool a_planar = false;
   bool b_planar = false;
   bool c_planar = false;
 };
 
+static __MATX_INLINE__ index_t GemmOpRows(cublasOperation_t op, index_t rows,
+                                          index_t cols)
+{
+  return op == CUBLAS_OP_N ? rows : cols;
+}
+
+static __MATX_INLINE__ index_t GemmOpCols(cublasOperation_t op, index_t rows,
+                                          index_t cols)
+{
+  return op == CUBLAS_OP_N ? cols : rows;
+}
+
+template <typename T>
+struct is_hermitian_trans_op : std::false_type {
+};
+
+template <typename T1>
+struct is_hermitian_trans_op<HermitianTransOp<T1>> : std::true_type {
+};
+
+template <typename T>
+inline constexpr bool is_hermitian_trans_op_v =
+    is_hermitian_trans_op<remove_cvref_t<T>>::value;
+
+template <typename T>
+struct is_conj_tensor_view_unary_op : std::false_type {
+};
+
+// cuBLASLt can express conjugation of a tensor view by passing the transposed
+// view with CUBLAS_OP_C. Keep this trait narrow so that nested expressions such
+// as conj(hermitianT(A)) fall back to the generic copy path and are evaluated
+// as written.
+template <typename I1>
+struct is_conj_tensor_view_unary_op<matxUnaryOp<I1, ConjOp<typename I1::value_type>>>
+    : std::bool_constant<is_matx_tensor_v<remove_cvref_t<I1>>> {
+};
+
+template <typename T>
+inline constexpr bool is_conj_tensor_view_unary_op_v =
+    is_conj_tensor_view_unary_op<remove_cvref_t<T>>::value;
+
+template <typename T>
+static constexpr cublasOperation_t MatMulConjTransposeOp()
+{
+  if constexpr (is_complex_v<T> && !is_complex_half_v<T>) {
+    return CUBLAS_OP_C;
+  }
+  else {
+    return CUBLAS_OP_T;
+  }
+}
+
+template <typename T>
+static constexpr bool CanUseCublasLtConjTransposeOp()
+{
+  return !is_complex_half_v<T>;
+}
+
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB,
           MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
 class MatMulCUDAHandle_t {
@@ -188,14 +251,13 @@ class MatMulCUDAHandle_t {
    *
    */
   MatMulCUDAHandle_t(TensorTypeC &c, const TensorTypeA &a,
-                     const TensorTypeB &b)
+                     const TensorTypeB &b,
+                     cublasOperation_t opA = CUBLAS_OP_N,
+                     cublasOperation_t opB = CUBLAS_OP_N)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
 
     static_assert(RANK >= 2);
-    MATX_ASSERT(a.Size(TensorTypeA::Rank() - 1) == b.Size(TensorTypeB::Rank() - 2), matxInvalidSize);
-    MATX_ASSERT(c.Size(RANK - 1) == b.Size(TensorTypeB::Rank() - 1), matxInvalidSize);
-    MATX_ASSERT(c.Size(RANK - 2) == a.Size(TensorTypeA::Rank() - 2), matxInvalidSize);
 
     // Ensure batch dimensions are equal
     for (int i = 0; i < RANK - 2; i++) {
@@ -208,7 +270,17 @@ class MatMulCUDAHandle_t {
     }
 
     // This must come before the things below to properly set class parameters
-    params_ = GetGemmParams(c, a, b);
+    params_ = GetGemmParams(c, a, b, opA, opB);
+
+    MATX_ASSERT(GemmOpCols(params_.opA, params_.a_rows, params_.a_cols) ==
+                    GemmOpRows(params_.opB, params_.b_rows, params_.b_cols),
+                matxInvalidSize);
+    MATX_ASSERT(c.Size(RANK - 1) ==
+                    GemmOpCols(params_.opB, params_.b_rows, params_.b_cols),
+                matxInvalidSize);
+    MATX_ASSERT(c.Size(RANK - 2) ==
+                    GemmOpRows(params_.opA, params_.a_rows, params_.a_cols),
+                matxInvalidSize);
 
     if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
       // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB for Hopper+:
@@ -265,7 +337,9 @@ class MatMulCUDAHandle_t {
   }
 
   static detail::MatMulCUDAParams_t GetGemmParams(TensorTypeC &c, const TensorTypeA &a,
-                                                  const TensorTypeB &b)
+                                                  const TensorTypeB &b,
+                                                  cublasOperation_t requestedOpA = CUBLAS_OP_N,
+                                                  cublasOperation_t requestedOpB = CUBLAS_OP_N)
   {
     /* If a user passes in a tensor where the last two dimensions are transposed we retain
        the original size parameters, but tell the underlying libraries that the tensors are
@@ -282,6 +356,12 @@ class MatMulCUDAHandle_t {
     params.a_planar = is_planar_complex_v<typename TensorTypeA::value_type>;
     params.b_planar = is_planar_complex_v<typename TensorTypeB::value_type>;
     params.c_planar = is_planar_complex_v<typename TensorTypeC::value_type>;
+    params.a_rows = a.Size(TensorTypeA::Rank() - 2);
+    params.a_cols = a.Size(TensorTypeA::Rank() - 1);
+    params.b_rows = b.Size(TensorTypeB::Rank() - 2);
+    params.b_cols = b.Size(TensorTypeB::Rank() - 1);
+    params.c_rows = c.Size(RANK - 2);
+    params.c_cols = c.Size(RANK - 1);
 
     // Batches
     params.batch = 1;
@@ -361,12 +441,11 @@ class MatMulCUDAHandle_t {
     else if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
       params.opA = CUBLAS_OP_N;
       params.opB = CUBLAS_OP_N;
-      params.m = static_cast<int>(b.Size(TensorTypeB::Rank() - 1));
-      params.n = static_cast<int>(a.Size(TensorTypeA::Rank() - 2));
-      params.k =
-          static_cast<int>(a.Size(TensorTypeA::Rank() - 2)); // Gemm Problem dimensions
-      params.lda = static_cast<int>(b.Stride(TensorTypeB::Rank() - 1));
-      params.ldb = static_cast<int>(a.Stride(TensorTypeA::Rank() - 1));
+      params.m = params.a_rows;
+      params.n = params.b_cols;
+      params.k = params.a_cols; // Gemm Problem dimensions
+      params.lda = static_cast<int>(a.Stride(TensorTypeA::Rank() - 2));
+      params.ldb = static_cast<int>(b.Stride(TensorTypeB::Rank() - 2));
      params.ldc = static_cast<int>(c.Stride(RANK - 1));
    }
  }
 
@@ -374,24 +453,27 @@ class MatMulCUDAHandle_t {
    if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
      if constexpr (is_complex_half_v<typename TensorTypeA::value_type>) { // For half complex we always copy to a new tensor so it is always cublas op N
-       params.opA = CUBLAS_OP_N;
+       params.orderA = MEM_ORDER_ROW_MAJOR;
      }
      else if ( a.Stride(TensorTypeA::Rank()-1) > 1 // last stride > 1
         || (a.Stride(TensorTypeA::Rank()-1) == 1 && a.Stride(TensorTypeA::Rank()-2) == 1 && a.Size(TensorTypeA::Rank()-1) != 1)) { // last strides both equal 1 and size > 1
-       params.opA = CUBLAS_OP_T;
+       params.orderA = MEM_ORDER_COL_MAJOR;
      }
      else { // otherwise row major
-       params.opA = CUBLAS_OP_N;
+       params.orderA = MEM_ORDER_ROW_MAJOR;
      }
 
      if constexpr (is_complex_half_v<typename TensorTypeB::value_type>) { // For half complex we always copy to a new tensor so it is always cublas op N
-       params.opB = CUBLAS_OP_N;
+       params.orderB = MEM_ORDER_ROW_MAJOR;
      }
      else if ( b.Stride(TensorTypeB::Rank()-1) > 1 // last stride > 1
         || (b.Stride(TensorTypeB::Rank()-1) == 1 && b.Stride(TensorTypeB::Rank()-2) == 1 && b.Size(TensorTypeB::Rank()-1) != 1)) { // last strides both equal 1 and size > 1
-       params.opB = CUBLAS_OP_T;
+       params.orderB = MEM_ORDER_COL_MAJOR;
      }
      else { // otherwise row major
-       params.opB = CUBLAS_OP_N;
+       params.orderB = MEM_ORDER_ROW_MAJOR;
      }
 
+     params.opA = requestedOpA;
+     params.opB = requestedOpB;
+
      params.a_rows = a.Size(TensorTypeA::Rank() - 2);
      params.a_cols = a.Size(TensorTypeA::Rank() - 1);
      params.b_rows = b.Size(TensorTypeB::Rank() - 2);
@@ -400,7 +482,7 @@ class MatMulCUDAHandle_t {
      // set lda/ldb according to transpose modes. If we pass in a cloned tensor the second stride will be
      // 0, which cuBLAS doesn't like even though it's unused. Set it to something that it would be if the
      // matrix had more than 1 row.
-     if (params.opB == CUBLAS_OP_T) {
+     if (params.orderB == MEM_ORDER_COL_MAJOR) {
        params.ldb = b.Stride(TensorTypeB::Rank() - 1);
      }
      else {
@@ -408,7 +490,7 @@ class MatMulCUDAHandle_t {
        params.ldb = (params.ldb == 0) ? b.Size(TensorTypeB::Rank() - 1) : params.ldb;
      }
 
-     if (params.opA == CUBLAS_OP_T) {
+     if (params.orderA == MEM_ORDER_COL_MAJOR) {
        params.lda = a.Stride(TensorTypeA::Rank() - 1);
      }
      else {
@@ -425,8 +507,11 @@ class MatMulCUDAHandle_t {
        params.ldb = b.Size(TensorTypeB::Rank()-1);
      }
 
-     params.c_rows = params.a_rows;
-     params.c_cols = params.b_cols;
+     params.m = GemmOpRows(params.opA, params.a_rows, params.a_cols);
+     params.k = GemmOpCols(params.opA, params.a_rows, params.a_cols);
+     params.n = GemmOpCols(params.opB, params.b_rows, params.b_cols);
+     params.c_rows = params.m;
+     params.c_cols = params.n;
      params.ldc = c.Stride(RANK - 2);
 
      // For complex half paths we launch as planar row-major. Use compact
@@ -440,10 +525,15 @@ class MatMulCUDAHandle_t {
    else if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
      params.opA = CUBLAS_OP_N;
      params.opB = CUBLAS_OP_N;
-     params.m = static_cast<int>(a.Size(TensorTypeA::Rank() - 2));
-     params.n = static_cast<int>(b.Size(TensorTypeB::Rank() - 1));
-     params.k =
-         static_cast<int>(a.Size(TensorTypeA::Rank() - 1)); // Gemm Problem dimensions
+     params.a_rows = a.Size(TensorTypeA::Rank() - 2);
+     params.a_cols = a.Size(TensorTypeA::Rank() - 1);
+     params.b_rows = b.Size(TensorTypeB::Rank() - 2);
+     params.b_cols = b.Size(TensorTypeB::Rank() - 1);
+     params.c_rows = c.Size(RANK - 2);
+     params.c_cols = c.Size(RANK - 1);
+     params.m = params.a_rows;
+     params.n = params.b_cols;
+     params.k = params.a_cols; // Gemm Problem dimensions
      params.lda = static_cast<int>(a.Stride(TensorTypeA::Rank() - 2));
      params.ldb = static_cast<int>(b.Stride(TensorTypeB::Rank() - 2));
      params.ldc = static_cast<int>(c.Stride(RANK - 2));
@@ -572,17 +662,16 @@ class MatMulCUDAHandle_t {
    cublasLtOrder_t rowOrder = CUBLASLT_ORDER_ROW;
    cublasLtOrder_t colOrder = CUBLASLT_ORDER_COL;
 
-   auto op = CUBLAS_OP_N;
    // A operation
    ret = cublasLtMatmulDescSetAttribute(
-       operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op,
-       sizeof(op));
+       operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &params_.opA,
+       sizeof(params_.opA));
    MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
 
    // B operation
    ret = cublasLtMatmulDescSetAttribute(
-       operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op,
-       sizeof(op));
+       operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &params_.opB,
+       sizeof(params_.opB));
    MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
 
    // Update this later when we're more flexible on compute type
@@ -623,7 +712,7 @@ class MatMulCUDAHandle_t {
    MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
 
    // Matrix data order
-   if (params_.opA == CUBLAS_OP_T) {
+   if (params_.orderA == MEM_ORDER_COL_MAJOR) {
      ret = cublasLtMatrixLayoutSetAttribute(
          Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder,
          sizeof(colOrder));
@@ -635,7 +724,7 @@ class MatMulCUDAHandle_t {
    }
    MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
 
-   if (params_.opB == CUBLAS_OP_T) {
+   if (params_.orderB == MEM_ORDER_COL_MAJOR) {
      ret = cublasLtMatrixLayoutSetAttribute(
          Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder,
          sizeof(colOrder));
@@ -647,9 +736,16 @@ class MatMulCUDAHandle_t {
    }
    MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
 
-   ret = cublasLtMatrixLayoutSetAttribute(
-       Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder,
-       sizeof(rowOrder));
+   if (params_.orderC == MEM_ORDER_COL_MAJOR) {
+     ret = cublasLtMatrixLayoutSetAttribute(
+         Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder,
+         sizeof(colOrder));
+   }
+   else {
+     ret = cublasLtMatrixLayoutSetAttribute(
+         Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder,
+         sizeof(rowOrder));
+   }
    MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError);
 
    ret = cublasLtMatrixLayoutSetAttribute(
@@ -1126,6 +1222,11 @@ struct MatMulCUDAParamsKeyHash {
    return std::hash<uint64_t>()(k.m) + std::hash<uint64_t>()(k.n) +
           std::hash<uint64_t>()(k.k) + std::hash<uint64_t>()(k.batch) +
           std::hash<uint64_t>()(k.prov) +
+          std::hash<uint64_t>()(k.opA) +
+          std::hash<uint64_t>()(k.opB) +
+          std::hash<uint64_t>()(k.orderA) +
+          std::hash<uint64_t>()(k.orderB) +
+          std::hash<uint64_t>()(k.orderC) +
          std::hash<uint64_t>()(static_cast<uint64_t>(k.a_planar)) +
          std::hash<uint64_t>()(static_cast<uint64_t>(k.b_planar)) +
          std::hash<uint64_t>()(static_cast<uint64_t>(k.c_planar)) +
@@ -1150,6 +1251,8 @@ struct MatMulCUDAParamsKeyEq {
          l.ldb == t.ldb && l.ldc == t.ldc && l.batch == t.batch &&
          l.prov == t.prov && l.dtype == t.dtype && l.opA == t.opA &&
          l.opB == t.opB && l.rank == t.rank &&
+         l.orderA == t.orderA && l.orderB == t.orderB &&
+         l.orderC == t.orderC &&
          l.a_planar == t.a_planar &&
          l.b_planar == t.b_planar &&
          l.c_planar == t.c_planar;
@@ -1180,6 +1283,81 @@ __MATX_INLINE__ auto getCublasSupportedTensor( const Op &in, cudaStream_t stream
   return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
 }
 
+namespace detail {
+
+template <typename TensorType, typename Op>
+__MATX_INLINE__ void CopyCublasInputIfNeeded(TensorType &tensor, const Op &op,
+                                             cudaStream_t stream)
+{
+  if constexpr (!is_matx_transform_op<Op>()) {
+    if (!tensor.isSameView(op)) {
+      (tensor = op).run(stream);
+    }
+  }
+}
+
+template <MatXMatMulProvider_t PROV, typename Op, typename Func>
+__MATX_INLINE__ void WithMatmulOperand(const Op &op, cudaStream_t stream,
+                                       Func &&func)
+{
+  using OpType = remove_cvref_t<Op>;
+  constexpr bool can_use_metadata_op =
+      PROV == PROVIDER_TYPE_CUBLASLT &&
+      std::is_same_v<typename OpType::value_type,
+                     typename base_type_t<OpType>::value_type> &&
+      CanUseCublasLtConjTransposeOp<typename OpType::value_type>();
+
+  if constexpr (can_use_metadata_op && is_hermitian_trans_op_v<OpType>) {
+    const auto &input = op.Input();
+    auto tensor = getCublasSupportedTensor(input, stream);
+    CopyCublasInputIfNeeded(tensor, input, stream);
+    func(tensor, MatMulConjTransposeOp<typename OpType::value_type>());
+  }
+  else if constexpr (can_use_metadata_op && is_conj_tensor_view_unary_op_v<OpType>) {
+    const auto &input = op.Input();
+    auto transposed = transpose_matrix(input);
+    auto tensor = getCublasSupportedTensor(transposed, stream);
+    CopyCublasInputIfNeeded(tensor, transposed, stream);
+    func(tensor, MatMulConjTransposeOp<typename OpType::value_type>());
+  }
+  else {
+    auto tensor = getCublasSupportedTensor(op, stream);
+    CopyCublasInputIfNeeded(tensor, op, stream);
+    func(tensor, CUBLAS_OP_N);
+  }
+}
+
+template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB,
+          MatXMatMulProvider_t PROV>
+void MatMulCUDAExecPrepared(TensorTypeC &c, const TensorTypeA &a,
+                            const TensorTypeB &b, const cudaExecutor &exec,
+                            cudaStream_t stream, float alpha, float beta,
+                            cublasOperation_t opA, cublasOperation_t opB)
+{
+  auto params =
+      detail::MatMulCUDAHandle_t<TensorTypeC, TensorTypeA, TensorTypeB, PROV>::
+          GetGemmParams(c, a, b, opA, opB);
+  params.stream = stream;
+
+  using cache_val_type =
+      detail::MatMulCUDAHandle_t<TensorTypeC, TensorTypeA, TensorTypeB, PROV>;
+  auto cache_id = detail::GetCacheIdFromType<detail::matmul_cuda_cache_t>();
+  MATX_LOG_DEBUG("MatMul transform: cache_id={}", cache_id);
+  detail::GetCache().LookupAndExec<detail::matmul_cuda_cache_t>(
+      cache_id,
+      params,
+      [&]() {
+        return std::make_shared<cache_val_type>(c, a, b, opA, opB);
+      },
+      [&](std::shared_ptr<cache_val_type> cache_type) {
+        cache_type->Exec(c, a, b, stream, alpha, beta);
+      },
+      exec
+  );
+}
+
+} // end namespace detail
+
 /**
  * Run a GEMM without a plan
  *
@@ -1236,22 +1414,8 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
                 "Combination of A/B/C types are not supported");
 
   // CublasLt does not support operators and certain transpose modes.
-  // Grab a suppported tensor here and copy in if necessary.
+  // Grab a supported C tensor here and copy in if necessary.
   auto c = getCublasSupportedTensor(C, stream);
-  auto a = getCublasSupportedTensor(A_, stream);
-  auto b = getCublasSupportedTensor(B_, stream);
-
-  typedef decltype(c) ctype;
-  typedef decltype(a) atype;
-  typedef decltype(b) btype;
-
-  if(!is_matx_transform_op<atype>() && !a.isSameView(A_)) {
-    (a = A_).run(stream);
-  }
-
-  if(!is_matx_transform_op<btype>() && !b.isSameView(B_)) {
-    (b = B_).run(stream);
-  }
 
   if(beta != 0 && !c.isSameView(C)) {
     (c = C).run(stream);
@@ -1262,31 +1426,28 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
   // Use the identity CT = BT * AT to do the transpose through the gemm automatically. Note we only want to do this transpose if
   // the rightmost stride is !=1 or this function will be an infinite recursion.
  if ( c.Stride(c.Rank()-2) == 1 && c.Stride(c.Rank()-1) > 1 ) { // column major check
+    auto a = getCublasSupportedTensor(A_, stream);
+    auto b = getCublasSupportedTensor(B_, stream);
+    detail::CopyCublasInputIfNeeded(a, A_, stream);
+    detail::CopyCublasInputIfNeeded(b, B_, stream);
+
    // Column major
    matmul_impl(transpose_matrix(c), transpose_matrix(b), transpose_matrix(a),
               exec, alpha, beta);
  }
  else
 #endif
  {
-   // Get parameters required by these tensors
-   auto params =
-       detail::MatMulCUDAHandle_t<ctype, atype, btype, PROV>::GetGemmParams(c, a, b);
-   params.stream = stream;
-
-   using cache_val_type = detail::MatMulCUDAHandle_t<ctype, atype, btype, PROV>;
-   auto cache_id = detail::GetCacheIdFromType<detail::matmul_cuda_cache_t>();
-   MATX_LOG_DEBUG("MatMul transform: cache_id={}", cache_id);
-   detail::GetCache().LookupAndExec<detail::matmul_cuda_cache_t>(
-       cache_id,
-       params,
-       [&]() {
-         return std::make_shared<cache_val_type>(c, a, b);
-       },
-       [&](std::shared_ptr<cache_val_type> cache_type) {
-         cache_type->Exec(c, a, b, stream, alpha, beta);
-       },
-       exec
-   );
-  }
+   detail::WithMatmulOperand<PROV>(
+       A_, stream,
+       [&](const auto &a, cublasOperation_t opA) {
+         detail::WithMatmulOperand<PROV>(
+             B_, stream,
+             [&](const auto &b, cublasOperation_t opB) {
+               detail::MatMulCUDAExecPrepared<decltype(c), remove_cvref_t<decltype(a)>,
+                                              remove_cvref_t<decltype(b)>, PROV>(
+                   c, a, b, exec, stream, alpha, beta, opA, opB);
+             });
+       });
+  }
 
  // if c and C are not the same then we need to copy results out.
  if(!c.isSameView(C)) {
diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu
index 286c47cdd..c6361f152 100644
--- a/test/00_transform/MatMul.cu
+++ b/test/00_transform/MatMul.cu
@@ -85,11 +85,32 @@ class MatMulTestFloatNonComplexTypes : public MatMulTest<TensorType> {
 };
 template <typename TensorType>
 class MatMulTestComplexHalfPlanarTypes : public MatMulTest<TensorType> {
 };
+template <typename TensorType>
+class MatMulTestComplexNonHalfCUDA : public MatMulTest<TensorType> {
+};
+template <typename TensorType>
+class MatMulTestFloatNonComplexNonHalfCUDA : public MatMulTest<TensorType> {
+};
 
 TYPED_TEST_SUITE(MatMulTestFloatTypes, MatXTypesFloatAllExecs);
 TYPED_TEST_SUITE(MatMulTestFloatNonHalfTypes, MatXFloatNonHalfTypesAllExecs);
 TYPED_TEST_SUITE(MatMulTestFloatNonComplexTypes, MatXTypesFloatNonComplexAllExecs);
 TYPED_TEST_SUITE(MatMulTestComplexHalfPlanarTypes, MatXComplexHalfPlanarTypesAllExecs);
+TYPED_TEST_SUITE(MatMulTestComplexNonHalfCUDA, MatXComplexNonHalfTypesCUDAExec);
+TYPED_TEST_SUITE(MatMulTestFloatNonComplexNonHalfCUDA, MatXFloatNonComplexNonHalfTypesCUDAExec);
+
+template <typename T>
+T MatMulHermitianTestValue(index_t i, index_t j, index_t batch = 0)
+{
+  const float r = static_cast<float>((batch + 1) * 13 + i * 3 - j * 2);
+  const float im = static_cast<float>((batch + 1) * 5 - i + j * 4);
+  if constexpr (is_cuda_complex_v<T>) {
+    return T{r / 17.0f, im / 19.0f};
+  }
+  else {
+    return static_cast<T>(r / 17.0f);
+  }
+}
 
 template <typename T>
 struct float_to_complex
@@ -220,6 +241,196 @@ TYPED_TEST(MatMulTestFloatTypes, SmallRectBTranspose)
   MATX_EXIT_HANDLER();
 }
 
+TYPED_TEST(MatMulTestComplexNonHalfCUDA, SmallRectAHermitian)
+{
+  MATX_ENTER_HANDLER();
+  using TestType = cuda::std::tuple_element_t<0, TypeParam>;
+  using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
+  if constexpr (!detail::CheckMatMulSupport<TestType, ExecType>()) {
+    GTEST_SKIP();
+  } else {
+    constexpr index_t m = 5;
+    constexpr index_t k = 3;
+    constexpr index_t n = 4;
+    tensor_t<TestType, 2> a{{m, k}};
+    tensor_t<TestType, 2> b{{m, n}};
+    tensor_t<TestType, 2> c_hermitian{{k, n}};
+    tensor_t<TestType, 2> c_conj_transpose{{k, n}};
+    tensor_t<TestType, 2> c_conj_hermitian{{k, n}};
+
+    for (index_t i = 0; i < m; i++) {
+      for (index_t j = 0; j < k; j++) {
+        a(i, j) = MatMulHermitianTestValue<TestType>(i, j);
+      }
+    }
+
+    for (index_t i = 0; i < m; i++) {
+      for (index_t j = 0; j < n; j++) {
+        b(i, j) = MatMulHermitianTestValue<TestType>(i, j + k);
+      }
+    }
+
+    (c_hermitian = matmul(hermitianT(a), b)).run(this->exec);
+    (c_conj_transpose = matmul(conj(transpose_matrix(a)), b)).run(this->exec);
+    (c_conj_hermitian = matmul(conj(hermitianT(a)), b)).run(this->exec);
+    this->exec.sync();
+
+    for (index_t i = 0; i < k; i++) {
+      for (index_t j = 0; j < n; j++) {
+        TestType expected{};
+        TestType expected_conj_hermitian{};
+        for (index_t p = 0; p < m; p++) {
+          expected += detail::scalar_internal_conj(a(p, i)) * b(p, j);
+          expected_conj_hermitian += a(p, i) * b(p, j);
+        }
+        EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_hermitian(i, j), expected, this->thresh));
+        EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_conj_transpose(i, j), expected, this->thresh));
+        EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_conj_hermitian(i, j),
+                                               expected_conj_hermitian,
+                                               this->thresh));
+      }
+    }
+  }
+  MATX_EXIT_HANDLER();
+}
+
+TYPED_TEST(MatMulTestComplexNonHalfCUDA, SmallRectBHermitian)
+{
+  MATX_ENTER_HANDLER();
+  using TestType = cuda::std::tuple_element_t<0, TypeParam>;
+  using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
+  if constexpr (!detail::CheckMatMulSupport<TestType, ExecType>()) {
+    GTEST_SKIP();
+  } else {
+    constexpr index_t m = 4;
+    constexpr index_t k = 5;
+    constexpr index_t n = 3;
+    tensor_t<TestType, 2> a{{m, k}};
+    tensor_t<TestType, 2> b{{n, k}};
+    tensor_t<TestType, 2> c_hermitian{{m, n}};
+    tensor_t<TestType, 2> c_conj_transpose{{m, n}};
+
+    for (index_t i = 0; i < m; i++) {
+      for (index_t j = 0; j < k; j++) {
+        a(i, j) = MatMulHermitianTestValue<TestType>(i, j);
+      }
+    }
+
+    for (index_t i = 0; i < n; i++) {
+      for (index_t j = 0; j < k; j++) {
+        b(i, j) = MatMulHermitianTestValue<TestType>(i + m, j);
+      }
+    }
+
+    (c_hermitian = matmul(a, hermitianT(b))).run(this->exec);
+    (c_conj_transpose = matmul(a, conj(transpose_matrix(b)))).run(this->exec);
+    this->exec.sync();
+
+    for (index_t i = 0; i < m; i++) {
+      for (index_t j = 0; j < n; j++) {
+        TestType expected{};
+        for (index_t p = 0; p < k; p++) {
+          expected += a(i, p) * detail::scalar_internal_conj(b(j, p));
+        }
+        EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_hermitian(i, j), expected, this->thresh));
+        EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_conj_transpose(i, j), expected, this->thresh));
+      }
+    }
+  }
+  MATX_EXIT_HANDLER();
+}
+
+TYPED_TEST(MatMulTestComplexNonHalfCUDA, BatchedAHermitian)
+{
+  MATX_ENTER_HANDLER();
+  using TestType = cuda::std::tuple_element_t<0, TypeParam>;
+  using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
+  if constexpr (!detail::CheckMatMulSupport<TestType, ExecType>()) {
+    GTEST_SKIP();
+  } else {
+    constexpr index_t batches = 2;
+    constexpr index_t m = 4;
+    constexpr index_t k = 3;
+    constexpr index_t n = 5;
+    tensor_t<TestType, 3> a{{batches, m, k}};
+    tensor_t<TestType, 3> b{{batches, m, n}};
+    tensor_t<TestType, 3> c{{batches, k, n}};
+
+    for (index_t batch = 0; batch < batches; batch++) {
+      for (index_t i = 0; i < m; i++) {
+        for (index_t j = 0; j < k; j++) {
+          a(batch, i, j) = MatMulHermitianTestValue<TestType>(i, j, batch);
+        }
+      }
+
+      for (index_t i = 0; i < m; i++) {
+        for (index_t j = 0; j < n; j++) {
+          b(batch, i, j) = MatMulHermitianTestValue<TestType>(i, j + k, batch);
+        }
+      }
+    }
+
+    (c = matmul(hermitianT(a), b)).run(this->exec);
+    this->exec.sync();
+
+    for (index_t batch = 0; batch < batches; batch++) {
+      for (index_t i = 0; i < k; i++) {
+        for (index_t j = 0; j < n; j++) {
+          TestType expected{};
+          for (index_t p = 0; p < m; p++) {
+            expected += detail::scalar_internal_conj(a(batch, p, i)) * b(batch, p, j);
+          }
+          EXPECT_TRUE(MatXUtils::MatXTypeCompare(c(batch, i, j), expected, this->thresh));
+        }
+      }
+    }
+  }
+  MATX_EXIT_HANDLER();
+}
+
+TYPED_TEST(MatMulTestFloatNonComplexNonHalfCUDA, SmallRectRealHermitian)
+{
+  MATX_ENTER_HANDLER();
+  using TestType = cuda::std::tuple_element_t<0, TypeParam>;
+  using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
+  if constexpr (!detail::CheckMatMulSupport<TestType, ExecType>()) {
+    GTEST_SKIP();
+  } else {
+    constexpr index_t m = 5;
+    constexpr index_t k = 3;
+    constexpr index_t n = 4;
+    tensor_t<TestType, 2> a{{m, k}};
+    tensor_t<TestType, 2> b{{m, n}};
+    tensor_t<TestType, 2> c{{k, n}};
+
+    for (index_t i = 0; i < m; i++) {
+      for (index_t j = 0; j < k; j++) {
+        a(i, j) = MatMulHermitianTestValue<TestType>(i, j);
+      }
+    }
+
+    for (index_t i = 0; i < m; i++) {
+      for (index_t j = 0; j < n; j++) {
+        b(i, j) = MatMulHermitianTestValue<TestType>(i, j + k);
+      }
+    }
+
+    (c = matmul(hermitianT(a), b)).run(this->exec);
+    this->exec.sync();
+
+    for (index_t i = 0; i < k; i++) {
+      for (index_t j = 0; j < n; j++) {
+        TestType expected{};
+        for (index_t p = 0; p < m; p++) {
+          expected += a(p, i) * b(p, j);
+        }
+        EXPECT_TRUE(MatXUtils::MatXTypeCompare(c(i, j), expected, this->thresh));
+      }
+    }
+  }
+  MATX_EXIT_HANDLER();
+}
+
 TYPED_TEST(MatMulTestFloatNonHalfTypes, SmallRectCTranspose)
 {
   MATX_ENTER_HANDLER();
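
Usage sketch (not part of the patch): with the operand-metadata path above, a matmul whose input is wrapped in hermitianT() or conj(transpose_matrix()) on the cuBLASLt provider should be dispatched as CUBLAS_OP_C over the operand's original storage instead of first materializing a conjugated copy. The following is a minimal illustration against MatX's public API (make_tensor, matmul, hermitianT, conj, transpose_matrix); the sizes and fill values here are arbitrary placeholders, not taken from the patch.

// sketch.cu -- exercises the new conjugate-transpose metadata path
#include "matx.h"

int main()
{
  using namespace matx;
  using complex = cuda::std::complex<float>;
  cudaExecutor exec{};

  auto a = make_tensor<complex>({5, 3});
  auto b = make_tensor<complex>({5, 4});
  auto c = make_tensor<complex>({3, 4});
  (a = complex{1.0f, -2.0f}).run(exec);  // placeholder fill values
  (b = complex{0.5f, 3.0f}).run(exec);

  // A^H * B: hermitianT(a) should be forwarded to cuBLASLt as CUBLAS_OP_C
  // metadata over a's storage rather than being copied and transposed first.
  (c = matmul(hermitianT(a), b)).run(exec);

  // conj(transpose_matrix(a)) should take the same fast path via the
  // is_conj_tensor_view_unary_op trait.
  (c = matmul(conj(transpose_matrix(a)), b)).run(exec);

  exec.sync();
  return 0;
}

The same expressions still compile and run on configurations the metadata path excludes (for example complex-half element types), where the conjugate is materialized by the generic copy path instead.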