diff --git a/provekit/common/src/prefix_covector.rs b/provekit/common/src/prefix_covector.rs index e0a813e6d..7272dc832 100644 --- a/provekit/common/src/prefix_covector.rs +++ b/provekit/common/src/prefix_covector.rs @@ -1,6 +1,7 @@ use { crate::FieldElement, ark_std::{One, Zero}, + rayon::prelude::*, whir::algebra::{dot, linear_form::LinearForm, multilinear_extend}, }; @@ -76,11 +77,18 @@ impl LinearForm for PrefixCovector { } fn accumulate(&self, accumulator: &mut [FieldElement], scalar: FieldElement) { - for (acc, val) in accumulator[..self.vector.len()] - .iter_mut() - .zip(&self.vector) - { - *acc += scalar * *val; + let accumulator = &mut accumulator[..self.vector.len()]; + if self.vector.len() > whir::utils::workload_size::() { + accumulator + .par_iter_mut() + .zip(self.vector.par_iter()) + .for_each(|(acc, val)| { + *acc += scalar * *val; + }); + } else { + for (acc, val) in accumulator.iter_mut().zip(&self.vector) { + *acc += scalar * *val; + } } } } diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 17a7c4299..527594646 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -208,9 +208,9 @@ pub fn transpose_r1cs_matrices(r1cs: &R1CS) -> (SparseMatrix, SparseMatrix, Spar /// external row. #[instrument(skip_all)] pub fn multiply_transposed_by_eq_alpha( - at: &SparseMatrix, - bt: &SparseMatrix, - ct: &SparseMatrix, + at: SparseMatrix, + bt: SparseMatrix, + ct: SparseMatrix, alpha: &[FieldElement], r1cs: &R1CS, ) -> [Vec; 3] { @@ -237,8 +237,8 @@ pub fn multiply_transposed_by_eq_alpha( #[instrument(skip_all)] pub fn calculate_external_row_of_r1cs_matrices( alpha: &[FieldElement], - r1cs: &R1CS, + r1cs: R1CS, ) -> [Vec; 3] { - let (at, bt, ct) = transpose_r1cs_matrices(r1cs); - multiply_transposed_by_eq_alpha(&at, &bt, &ct, alpha, r1cs) + let (at, bt, ct) = transpose_r1cs_matrices(&r1cs); + multiply_transposed_by_eq_alpha(at, bt, ct, alpha, &r1cs) } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 8578cd2df..f35255efa 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -151,7 +151,7 @@ impl WhirR1CSProver for WhirR1CSScheme { ); drop(full_witness); - let alphas = calculate_external_row_of_r1cs_matrices(&alpha, &r1cs); + let alphas = calculate_external_row_of_r1cs_matrices(&alpha, r1cs); let (x, public_weight) = get_public_weights(public_inputs, &mut merlin, self.m); let blinding_offset = blinding.offset; diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 744117b6c..28f613790 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -91,13 +91,8 @@ impl WhirR1CSVerifier for WhirR1CSScheme { ); let x: FieldElement = arthur.verifier_message(); - let alphas = multiply_transposed_by_eq_alpha( - &at, - &bt, - &ct, - &data_from_sumcheck_verifier.alpha, - r1cs, - ); + let alphas = + multiply_transposed_by_eq_alpha(at, bt, ct, &data_from_sumcheck_verifier.alpha, r1cs); let blinding_eval = data_from_sumcheck_verifier.blinding_eval; let blinding_weights = expand_powers::<4>(&data_from_sumcheck_verifier.alpha);