Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 88 additions & 11 deletions kernel/generic/dot.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,46 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)

if ( (inc_x == 1) && (inc_y == 1) )
{
#if V_SIMD && !defined(DSDOT)
const int vstep = v_nlanes_f32;
const int unrollx4 = n & (-vstep * 4);
const int unrollx = n & -vstep;
v_f32 vsum0 = v_zero_f32();
#if defined(DOUBLE) && V_SIMD && V_SIMD_F64 && !defined(DSDOT)
const int vstep = v_nlanes_f64;
const int unrollx4 = n & (-vstep * 4);
const int unrollx = n & -vstep;
v_f64 vsum0 = v_zero_f64();
v_f64 vsum1 = v_zero_f64();
v_f64 vsum2 = v_zero_f64();
v_f64 vsum3 = v_zero_f64();
while(i < unrollx4)
{
vsum0 = v_muladd_f64(
v_loadu_f64(x + i), v_loadu_f64(y + i), vsum0
);
vsum1 = v_muladd_f64(
v_loadu_f64(x + i + vstep), v_loadu_f64(y + i + vstep), vsum1
);
vsum2 = v_muladd_f64(
v_loadu_f64(x + i + vstep*2), v_loadu_f64(y + i + vstep*2), vsum2
);
vsum3 = v_muladd_f64(
v_loadu_f64(x + i + vstep*3), v_loadu_f64(y + i + vstep*3), vsum3
);
i += vstep*4;
}
vsum0 = v_add_f64(
v_add_f64(vsum0, vsum1), v_add_f64(vsum2 , vsum3)
);
while(i < unrollx)
{
vsum0 = v_muladd_f64(
v_loadu_f64(x + i), v_loadu_f64(y + i), vsum0
);
i += vstep;
}
dot = v_sum_f64(vsum0);
#elif V_SIMD && !defined(DSDOT)
const int vstep = v_nlanes_f32;
const int unrollx4 = n & (-vstep * 4);
const int unrollx = n & -vstep;
v_f32 vsum0 = v_zero_f32();
v_f32 vsum1 = v_zero_f32();
v_f32 vsum2 = v_zero_f32();
v_f32 vsum3 = v_zero_f32();
Expand Down Expand Up @@ -82,10 +117,54 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
i += vstep;
}
dot = v_sum_f32(vsum0);
#elif defined(DSDOT)
int n1 = n & -4;
for (; i < n1; i += 4)
{
#elif defined(DSDOT) && defined(ARCH_WASM) && V_SIMD && V_SIMD_F64
const int vstep = v_nlanes_f32;
const int unrollx4 = n & (-vstep * 4);
const int unrollx = n & -vstep;
v_f64 vsum0_lo = v_zero_f64();
v_f64 vsum0_hi = v_zero_f64();
v_f64 vsum1_lo = v_zero_f64();
v_f64 vsum1_hi = v_zero_f64();
v_f64 vsum2_lo = v_zero_f64();
v_f64 vsum2_hi = v_zero_f64();
v_f64 vsum3_lo = v_zero_f64();
v_f64 vsum3_hi = v_zero_f64();
while(i < unrollx4)
{
v_f32 vx0 = v_loadu_f32(x + i);
v_f32 vy0 = v_loadu_f32(y + i);
v_f32 vx1 = v_loadu_f32(x + i + vstep);
v_f32 vy1 = v_loadu_f32(y + i + vstep);
v_f32 vx2 = v_loadu_f32(x + i + vstep*2);
v_f32 vy2 = v_loadu_f32(y + i + vstep*2);
v_f32 vx3 = v_loadu_f32(x + i + vstep*3);
v_f32 vy3 = v_loadu_f32(y + i + vstep*3);

vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx0), v_cvt_f32_f64_lo(vy0), vsum0_lo);
vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx0), v_cvt_f32_f64_hi(vy0), vsum0_hi);
vsum1_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx1), v_cvt_f32_f64_lo(vy1), vsum1_lo);
vsum1_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx1), v_cvt_f32_f64_hi(vy1), vsum1_hi);
vsum2_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx2), v_cvt_f32_f64_lo(vy2), vsum2_lo);
vsum2_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx2), v_cvt_f32_f64_hi(vy2), vsum2_hi);
vsum3_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx3), v_cvt_f32_f64_lo(vy3), vsum3_lo);
vsum3_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx3), v_cvt_f32_f64_hi(vy3), vsum3_hi);
i += vstep*4;
}
vsum0_lo = v_add_f64(v_add_f64(vsum0_lo, vsum1_lo), v_add_f64(vsum2_lo, vsum3_lo));
vsum0_hi = v_add_f64(v_add_f64(vsum0_hi, vsum1_hi), v_add_f64(vsum2_hi, vsum3_hi));
while(i < unrollx)
{
v_f32 vx = v_loadu_f32(x + i);
v_f32 vy = v_loadu_f32(y + i);
vsum0_lo = v_muladd_f64(v_cvt_f32_f64_lo(vx), v_cvt_f32_f64_lo(vy), vsum0_lo);
vsum0_hi = v_muladd_f64(v_cvt_f32_f64_hi(vx), v_cvt_f32_f64_hi(vy), vsum0_hi);
i += vstep;
}
dot = v_sum_f64(vsum0_lo) + v_sum_f64(vsum0_hi);
#elif defined(DSDOT)
int n1 = n & -4;
for (; i < n1; i += 4)
{
dot += (double) y[i] * (double) x[i]
+ (double) y[i+1] * (double) x[i+1]
+ (double) y[i+2] * (double) x[i+2]
Expand Down Expand Up @@ -133,5 +212,3 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
return(dot);

}


9 changes: 9 additions & 0 deletions kernel/simd/intrin_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
{ return v_sub_f64(v_mul_f64(a, b), c); }

BLAS_FINLINE v_f64 v_cvt_f32_f64_lo(v_f32 a)
{ return wasm_f64x2_promote_low_f32x4(a); }

BLAS_FINLINE v_f64 v_cvt_f32_f64_hi(v_f32 a)
{
v128_t hi = wasm_i32x4_shuffle(a, a, 2, 3, 0, 1);
return wasm_f64x2_promote_low_f32x4(hi);
}

/***************************
* reduction
***************************/
Expand Down
4 changes: 2 additions & 2 deletions kernel/wasm/KERNEL.WASM128_GENERIC
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ DCOPYKERNEL = ../riscv64/copy.c
CCOPYKERNEL = ../riscv64/zcopy.c
ZCOPYKERNEL = ../riscv64/zcopy.c

SDOTKERNEL = ../riscv64/dot.c
DDOTKERNEL = ../riscv64/dot.c
SDOTKERNEL = ../generic/dot.c
DDOTKERNEL = ../generic/dot.c
CDOTKERNEL = ../riscv64/zdot.c
ZDOTKERNEL = ../riscv64/zdot.c
DSDOTKERNEL = ../generic/dot.c
Expand Down
Loading