diff --git a/kernel/generic/dot.c b/kernel/generic/dot.c index ba7c64a9aa..4303eb1f3d 100644 --- a/kernel/generic/dot.c +++ b/kernel/generic/dot.c @@ -47,11 +47,46 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) if ( (inc_x == 1) && (inc_y == 1) ) { -#if V_SIMD && !defined(DSDOT) - const int vstep = v_nlanes_f32; - const int unrollx4 = n & (-vstep * 4); - const int unrollx = n & -vstep; - v_f32 vsum0 = v_zero_f32(); + #if defined(DOUBLE) && V_SIMD && V_SIMD_F64 && !defined(DSDOT) + const int vstep = v_nlanes_f64; + const int unrollx4 = n & (-vstep * 4); + const int unrollx = n & -vstep; + v_f64 vsum0 = v_zero_f64(); + v_f64 vsum1 = v_zero_f64(); + v_f64 vsum2 = v_zero_f64(); + v_f64 vsum3 = v_zero_f64(); + while(i < unrollx4) + { + vsum0 = v_muladd_f64( + v_loadu_f64(x + i), v_loadu_f64(y + i), vsum0 + ); + vsum1 = v_muladd_f64( + v_loadu_f64(x + i + vstep), v_loadu_f64(y + i + vstep), vsum1 + ); + vsum2 = v_muladd_f64( + v_loadu_f64(x + i + vstep*2), v_loadu_f64(y + i + vstep*2), vsum2 + ); + vsum3 = v_muladd_f64( + v_loadu_f64(x + i + vstep*3), v_loadu_f64(y + i + vstep*3), vsum3 + ); + i += vstep*4; + } + vsum0 = v_add_f64( + v_add_f64(vsum0, vsum1), v_add_f64(vsum2 , vsum3) + ); + while(i < unrollx) + { + vsum0 = v_muladd_f64( + v_loadu_f64(x + i), v_loadu_f64(y + i), vsum0 + ); + i += vstep; + } + dot = v_sum_f64(vsum0); + #elif V_SIMD && !defined(DSDOT) + const int vstep = v_nlanes_f32; + const int unrollx4 = n & (-vstep * 4); + const int unrollx = n & -vstep; + v_f32 vsum0 = v_zero_f32(); v_f32 vsum1 = v_zero_f32(); v_f32 vsum2 = v_zero_f32(); v_f32 vsum3 = v_zero_f32(); @@ -82,10 +117,54 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) i += vstep; } dot = v_sum_f32(vsum0); -#elif defined(DSDOT) - int n1 = n & -4; - for (; i < n1; i += 4) - { + #elif defined(DSDOT) && defined(ARCH_WASM) && V_SIMD && V_SIMD_F64 + const int vstep = v_nlanes_f32; + const int unrollx4 = n & (-vstep * 4); + const int 
unrollx = n & -vstep; + v_f64 vsum0_lo = v_zero_f64(); + v_f64 vsum0_hi = v_zero_f64(); + v_f64 vsum1_lo = v_zero_f64(); + v_f64 vsum1_hi = v_zero_f64(); + v_f64 vsum2_lo = v_zero_f64(); + v_f64 vsum2_hi = v_zero_f64(); + v_f64 vsum3_lo = v_zero_f64(); + v_f64 vsum3_hi = v_zero_f64(); + while(i < unrollx4) + { + v_f32 vx0 = v_loadu_f32(x + i); + v_f32 vy0 = v_loadu_f32(y + i); + v_f32 vx1 = v_loadu_f32(x + i + vstep); + v_f32 vy1 = v_loadu_f32(y + i + vstep); + v_f32 vx2 = v_loadu_f32(x + i + vstep*2); + v_f32 vy2 = v_loadu_f32(y + i + vstep*2); + v_f32 vx3 = v_loadu_f32(x + i + vstep*3); + v_f32 vy3 = v_loadu_f32(y + i + vstep*3); + + vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx0), v_cvt_f32_f64_lo(vy0), vsum0_lo); + vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx0), v_cvt_f32_f64_hi(vy0), vsum0_hi); + vsum1_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx1), v_cvt_f32_f64_lo(vy1), vsum1_lo); + vsum1_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx1), v_cvt_f32_f64_hi(vy1), vsum1_hi); + vsum2_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx2), v_cvt_f32_f64_lo(vy2), vsum2_lo); + vsum2_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx2), v_cvt_f32_f64_hi(vy2), vsum2_hi); + vsum3_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx3), v_cvt_f32_f64_lo(vy3), vsum3_lo); + vsum3_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx3), v_cvt_f32_f64_hi(vy3), vsum3_hi); + i += vstep*4; + } + vsum0_lo = v_add_f64(v_add_f64(vsum0_lo, vsum1_lo), v_add_f64(vsum2_lo, vsum3_lo)); + vsum0_hi = v_add_f64(v_add_f64(vsum0_hi, vsum1_hi), v_add_f64(vsum2_hi, vsum3_hi)); + while(i < unrollx) + { + v_f32 vx = v_loadu_f32(x + i); + v_f32 vy = v_loadu_f32(y + i); + vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx), v_cvt_f32_f64_lo(vy), vsum0_lo); + vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx), v_cvt_f32_f64_hi(vy), vsum0_hi); + i += vstep; + } + dot = v_sum_f64(vsum0_lo) + v_sum_f64(vsum0_hi); + #elif defined(DSDOT) + int n1 = n & -4; + for (; i < n1; i += 4) + { dot += (double) y[i] * (double) x[i] + (double) y[i+1] * (double) x[i+1] + (double) y[i+2] * (double) x[i+2] 
@@ -133,5 +212,3 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) return(dot); } - - diff --git a/kernel/simd/intrin_wasm.h b/kernel/simd/intrin_wasm.h index 1e04c70127..f55c2f28be 100644 --- a/kernel/simd/intrin_wasm.h +++ b/kernel/simd/intrin_wasm.h @@ -33,6 +33,15 @@ BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c) BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c) { return v_sub_f64(v_mul_f64(a, b), c); } +BLAS_FINLINE v_f64 v_cvt_f32_f64_lo(v_f32 a) /* f32 -> f64 widening of the low half: promotes lanes 0 and 1 of a */ +{ return wasm_f64x2_promote_low_f32x4(a); } + +BLAS_FINLINE v_f64 v_cvt_f32_f64_hi(v_f32 a) /* f32 -> f64 widening of the high half: promotes lanes 2 and 3 of a */ +{ + v128_t hi = wasm_i32x4_shuffle(a, a, 2, 3, 0, 1); /* reorder 32-bit lanes to (2, 3, 0, 1) so the upper pair sits in lanes 0-1 */ + return wasm_f64x2_promote_low_f32x4(hi); /* only a promote-low intrinsic is used here, hence the shuffle above */ +} + /*************************** * reduction ***************************/ diff --git a/kernel/wasm/KERNEL.WASM128_GENERIC b/kernel/wasm/KERNEL.WASM128_GENERIC index 1f1946a015..7791880168 100644 --- a/kernel/wasm/KERNEL.WASM128_GENERIC +++ b/kernel/wasm/KERNEL.WASM128_GENERIC @@ -55,8 +55,8 @@ DCOPYKERNEL = ../riscv64/copy.c CCOPYKERNEL = ../riscv64/zcopy.c ZCOPYKERNEL = ../riscv64/zcopy.c -SDOTKERNEL = ../riscv64/dot.c -DDOTKERNEL = ../riscv64/dot.c +SDOTKERNEL = ../generic/dot.c +DDOTKERNEL = ../generic/dot.c CDOTKERNEL = ../riscv64/zdot.c ZDOTKERNEL = ../riscv64/zdot.c DSDOTKERNEL = ../generic/dot.c