diff --git a/.github/workflows/sha2.yml b/.github/workflows/sha2.yml index 3efb039ef..fdb9e9a00 100644 --- a/.github/workflows/sha2.yml +++ b/.github/workflows/sha2.yml @@ -169,13 +169,13 @@ jobs: - run: cargo install cross --git https://github.com/cross-rs/cross - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb riscv32-zknh: runs-on: ubuntu-latest @@ -188,13 +188,13 @@ jobs: components: rust-src - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb minimal-versions: uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master diff --git a/sha2/src/lib.rs b/sha2/src/lib.rs index 1efc78b6f..9d9c9f448 100644 --- a/sha2/src/lib.rs +++ b/sha2/src/lib.rs @@ -10,6 +10,7 @@ any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"), feature(riscv_ext_intrinsics) )] +#![allow(clippy::needless_range_loop)] #[cfg(all( any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"), diff --git a/sha2/src/sha256.rs b/sha2/src/sha256.rs index 4e0849020..6d5896f13 100644 --- a/sha2/src/sha256.rs +++ b/sha2/src/sha256.rs @@ -11,12 +11,14 @@ cfg_if::cfg_if! { sha2_backend = "riscv-zknh" ))] { mod riscv_zknh; + mod riscv_zknh_utils; use riscv_zknh::compress; } else if #[cfg(all( any(target_arch = "riscv32", target_arch = "riscv64"), sha2_backend = "riscv-zknh-compact" ))] { mod riscv_zknh_compact; + mod riscv_zknh_utils; use riscv_zknh_compact::compress; } else if #[cfg(target_arch = "aarch64")] { mod soft; diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index fe950bdc8..7477c6409 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { @@ -18,8 +21,34 @@ fn maj(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (x & z) ^ (y & z) } -#[allow(clippy::identity_op)] -fn round(state: &mut [u32; 8], block: &[u32; 16]) { +/// This function returns `k[R]`, but prevents compiler from inlining the indexed value +pub(super) fn opaque_load(k: &[u32]) -> u32 { + assert!(R < k.len()); + let dst; + #[cfg(target_arch = "riscv64")] + unsafe { + core::arch::asm!( + "lwu {dst}, 4*{R}({k})", + R = const R, + k = in(reg) k.as_ptr(), + dst = out(reg) dst, + options(pure, readonly, nostack, preserves_flags), + ); + } + #[cfg(target_arch = "riscv32")] + unsafe { + core::arch::asm!( + "lw {dst}, 4*{R}({k})", + R = const R, + k = in(reg) k.as_ptr(), + dst = out(reg) dst, + options(pure, readonly, nostack, preserves_flags), + ); + } + dst +} + +fn round(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) { let n = K32.len() - R; #[allow(clippy::identity_op)] let a = (n + 0) % 8; @@ -34,92 +63,65 @@ fn round(state: &mut [u32; 8], block: &[u32; 16]) { state[h] = state[h] .wrapping_add(unsafe { sha256sum1(state[e]) }) .wrapping_add(ch(state[e], state[f], state[g])) - // Force reading of constants from the static to prevent bad codegen - .wrapping_add(unsafe { core::ptr::read_volatile(&K32[R]) }) - .wrapping_add(block[R % 16]); + .wrapping_add(opaque_load::(k)) + .wrapping_add(block[R]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] .wrapping_add(unsafe { sha256sum0(state[a]) }) .wrapping_add(maj(state[a], state[b], state[c])) } -fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16]) { - round::(state, block); +fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], k: &[u32]) { + round::(state, block, k); - block[R % 16] = block[R % 16] + block[R] = block[R] .wrapping_add(unsafe { sha256sig1(block[(R + 14) % 16]) }) .wrapping_add(block[(R + 9) % 16]) .wrapping_add(unsafe { sha256sig0(block[(R + 1) % 16]) }); } +#[inline(always)] fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { let s = &mut state.clone(); let b = &mut block; - round_schedule::<0>(s, b); - round_schedule::<1>(s, b); - round_schedule::<2>(s, b); - round_schedule::<3>(s, b); - round_schedule::<4>(s, b); - round_schedule::<5>(s, b); - round_schedule::<6>(s, b); - round_schedule::<7>(s, b); - round_schedule::<8>(s, b); - round_schedule::<9>(s, b); - round_schedule::<10>(s, b); - round_schedule::<11>(s, b); - round_schedule::<12>(s, b); - round_schedule::<13>(s, b); - round_schedule::<14>(s, b); - round_schedule::<15>(s, b); - round_schedule::<16>(s, b); - round_schedule::<17>(s, b); - round_schedule::<18>(s, b); - round_schedule::<19>(s, b); - round_schedule::<20>(s, b); - round_schedule::<21>(s, b); - round_schedule::<22>(s, b); - round_schedule::<23>(s, b); - round_schedule::<24>(s, b); - round_schedule::<25>(s, b); - round_schedule::<26>(s, b); - round_schedule::<27>(s, b); - round_schedule::<28>(s, b); - round_schedule::<29>(s, b); - round_schedule::<30>(s, b); - round_schedule::<31>(s, b); - round_schedule::<32>(s, b); - round_schedule::<33>(s, b); - round_schedule::<34>(s, b); - round_schedule::<35>(s, b); - round_schedule::<36>(s, b); - round_schedule::<37>(s, b); - round_schedule::<38>(s, b); - round_schedule::<39>(s, b); - round_schedule::<40>(s, b); - round_schedule::<41>(s, b); - round_schedule::<42>(s, b); - round_schedule::<43>(s, b); - round_schedule::<44>(s, b); - round_schedule::<45>(s, b); - round_schedule::<46>(s, b); - round_schedule::<47>(s, b); - round::<48>(s, b); - round::<49>(s, b); - round::<50>(s, b); - round::<51>(s, b); - round::<52>(s, b); - round::<53>(s, b); - round::<54>(s, b); - round::<55>(s, b); - round::<56>(s, b); - round::<57>(s, b); - round::<58>(s, b); - round::<59>(s, b); - round::<60>(s, b); - round::<61>(s, b); - round::<62>(s, b); - round::<63>(s, b); + for i in 0..3 { + let k = &K32[16 * i..]; + round_schedule::<0>(s, b, k); + round_schedule::<1>(s, b, k); + round_schedule::<2>(s, b, k); + round_schedule::<3>(s, b, k); + round_schedule::<4>(s, b, k); + round_schedule::<5>(s, b, k); + round_schedule::<6>(s, b, k); + round_schedule::<7>(s, b, k); + round_schedule::<8>(s, b, k); + round_schedule::<9>(s, b, k); + round_schedule::<10>(s, b, k); + round_schedule::<11>(s, b, k); + round_schedule::<12>(s, b, k); + round_schedule::<13>(s, b, k); + round_schedule::<14>(s, b, k); + round_schedule::<15>(s, b, k); + } + + let k = &K32[48..]; + round::<0>(s, b, k); + round::<1>(s, b, k); + round::<2>(s, b, k); + round::<3>(s, b, k); + round::<4>(s, b, k); + round::<5>(s, b, k); + round::<6>(s, b, k); + round::<7>(s, b, k); + round::<8>(s, b, k); + round::<9>(s, b, k); + round::<10>(s, b, k); + round::<11>(s, b, k); + round::<12>(s, b, k); + round::<13>(s, b, k); + round::<14>(s, b, k); + round::<15>(s, b, k); for i in 0..8 { state[i] = state[i].wrapping_add(s[i]); @@ -127,7 +129,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { } pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter().map(super::to_u32s) { + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh_compact.rs index 98375cce7..bba510a30 100644 --- a/sha2/src/sha256/riscv_zknh_compact.rs +++ b/sha2/src/sha256/riscv_zknh_compact.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { @@ -43,9 +46,7 @@ fn round(state: &mut [u32; 8], block: &[u32; 16], r: usize) { } #[inline(always)] -fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) { - round(state, block, r); - +fn schedule(block: &mut [u32; 16], r: usize) { block[r % 16] = block[r % 16] .wrapping_add(unsafe { sha256sig1(block[(r + 14) % 16]) }) .wrapping_add(block[(r + 9) % 16]) @@ -54,14 +55,13 @@ fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) { #[inline(always)] fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { - let s = &mut state.clone(); - let b = &mut block; + let mut s = *state; - for i in 0..48 { - round_schedule(s, b, i); - } - for i in 48..64 { - round(s, b, i); + for r in 0..64 { + round(&mut s, &block, r); + if r < 48 { + schedule(&mut block, r) + } } for i in 0..8 { @@ -70,7 +70,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { } pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter().map(super::to_u32s) { + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha256/riscv_zknh_utils.rs b/sha2/src/sha256/riscv_zknh_utils.rs new file mode 100644 index 000000000..d75a0b1c1 --- /dev/null +++ b/sha2/src/sha256/riscv_zknh_utils.rs @@ -0,0 +1,80 @@ +use core::{arch::asm, ptr}; + +#[inline(always)] +pub(super) fn load_block(block: &[u8; 64]) -> [u32; 16] { + if block.as_ptr().cast::().is_aligned() { + load_aligned_block(block) + } else { + load_unaligned_block(block) + } +} + +#[inline(always)] +fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] { + let p: *const u32 = block.as_ptr().cast(); + debug_assert!(p.is_aligned()); + let mut res = [0u32; 16]; + for i in 0..16 { + let val = unsafe { ptr::read(p.add(i)) }; + res[i] = val.to_be(); + } + res +} + +#[inline(always)] +fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { + let offset = (block.as_ptr() as usize) % align_of::(); + debug_assert_ne!(offset, 0); + let off1 = (8 * offset) % 32; + let off2 = (32 - off1) % 32; + let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast(); + + let mut left: u32; + let mut res = [0u32; 16]; + + /// Use LW instruction on RV32 and LWU on RV64 + #[cfg(target_arch = "riscv32")] + macro_rules! lw { + ($r:literal) => { + concat!("lw ", $r) + }; + } + #[cfg(target_arch = "riscv64")] + macro_rules! lw { + ($r:literal) => { + concat!("lwu ", $r) + }; + } + + unsafe { + asm!( + lw!("{left}, 0({bp})"), // left = unsafe { ptr::read(bp) }; + "srl {left}, {left}, {off1}", // left >>= off1; + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + } + + for i in 0..15 { + let right = unsafe { ptr::read(bp.add(1 + i)) }; + res[i] = (left | (right << off2)).to_be(); + left = right >> off1; + } + + let right: u32; + unsafe { + asm!( + lw!("{right}, 16 * 4({bp})"), // right = ptr::read(bp.add(16)); + "sll {right}, {right}, {off2}", // right <<= off2; + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + } + res[15] = (left | right).to_be(); + + res +} diff --git a/sha2/src/sha512.rs b/sha2/src/sha512.rs index 08e802fcc..206792665 100644 --- a/sha2/src/sha512.rs +++ b/sha2/src/sha512.rs @@ -11,12 +11,14 @@ cfg_if::cfg_if! { sha2_backend = "riscv-zknh" ))] { mod riscv_zknh; + mod riscv_zknh_utils; use riscv_zknh::compress; } else if #[cfg(all( any(target_arch = "riscv32", target_arch = "riscv64"), sha2_backend = "riscv-zknh-compact" ))] { mod riscv_zknh_compact; + mod riscv_zknh_utils; use riscv_zknh_compact::compress; } else if #[cfg(target_arch = "aarch64")] { mod soft; diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index 31a327ebb..7be35ee84 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); #[cfg(target_arch = "riscv32")] unsafe fn sha512sum0(x: u64) -> u64 { @@ -46,7 +49,40 @@ fn maj(x: u64, y: u64, z: u64) -> u64 { (x & y) ^ (x & z) ^ (y & z) } -fn round(state: &mut [u64; 8], block: &[u64; 16]) { +/// This function returns `k[R]`, but prevents compiler from inlining the indexed value +pub(super) fn opaque_load(k: &[u64]) -> u64 { + use core::arch::asm; + assert!(R < k.len()); + #[cfg(target_arch = "riscv64")] + unsafe { + let dst; + asm!( + "ld {dst}, {N}({k})", + N = const 8 * R, + k = in(reg) k.as_ptr(), + dst = out(reg) dst, + options(pure, readonly, nostack, preserves_flags), + ); + dst + } + #[cfg(target_arch = "riscv32")] + unsafe { + let [hi, lo]: [u32; 2]; + asm!( + "lw {lo}, {N1}({k})", + "lw {hi}, {N2}({k})", + N1 = const 8 * R, + N2 = const 8 * R + 4, + k = in(reg) k.as_ptr(), + lo = out(reg) lo, + hi = out(reg) hi, + options(pure, readonly, nostack, preserves_flags), + ); + ((hi as u64) << 32) | (lo as u64) + } +} + +fn round(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { let n = K64.len() - R; #[allow(clippy::identity_op)] let a = (n + 0) % 8; @@ -61,19 +97,18 @@ fn round(state: &mut [u64; 8], block: &[u64; 16]) { state[h] = state[h] .wrapping_add(unsafe { sha512sum1(state[e]) }) .wrapping_add(ch(state[e], state[f], state[g])) - // Force reading of constants from the static to prevent bad codegen - .wrapping_add(unsafe { core::ptr::read_volatile(&K64[R]) }) - .wrapping_add(block[R % 16]); + .wrapping_add(opaque_load::(k)) + .wrapping_add(block[R]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] .wrapping_add(unsafe { sha512sum0(state[a]) }) .wrapping_add(maj(state[a], state[b], state[c])) } -fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16]) { - round::(state, block); +fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], k: &[u64]) { + round::(state, block, k); - block[R % 16] = block[R % 16] + block[R] = block[R] .wrapping_add(unsafe { sha512sig1(block[(R + 14) % 16]) }) .wrapping_add(block[(R + 9) % 16]) .wrapping_add(unsafe { sha512sig0(block[(R + 1) % 16]) }); @@ -83,86 +118,43 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { let s = &mut state.clone(); let b = &mut block; - round_schedule::<0>(s, b); - round_schedule::<1>(s, b); - round_schedule::<2>(s, b); - round_schedule::<3>(s, b); - round_schedule::<4>(s, b); - round_schedule::<5>(s, b); - round_schedule::<6>(s, b); - round_schedule::<7>(s, b); - round_schedule::<8>(s, b); - round_schedule::<9>(s, b); - round_schedule::<10>(s, b); - round_schedule::<11>(s, b); - round_schedule::<12>(s, b); - round_schedule::<13>(s, b); - round_schedule::<14>(s, b); - round_schedule::<15>(s, b); - round_schedule::<16>(s, b); - round_schedule::<17>(s, b); - round_schedule::<18>(s, b); - round_schedule::<19>(s, b); - round_schedule::<20>(s, b); - round_schedule::<21>(s, b); - round_schedule::<22>(s, b); - round_schedule::<23>(s, b); - round_schedule::<24>(s, b); - round_schedule::<25>(s, b); - round_schedule::<26>(s, b); - round_schedule::<27>(s, b); - round_schedule::<28>(s, b); - round_schedule::<29>(s, b); - round_schedule::<30>(s, b); - round_schedule::<31>(s, b); - round_schedule::<32>(s, b); - round_schedule::<33>(s, b); - round_schedule::<34>(s, b); - round_schedule::<35>(s, b); - round_schedule::<36>(s, b); - round_schedule::<37>(s, b); - round_schedule::<38>(s, b); - round_schedule::<39>(s, b); - round_schedule::<40>(s, b); - round_schedule::<41>(s, b); - round_schedule::<42>(s, b); - round_schedule::<43>(s, b); - round_schedule::<44>(s, b); - round_schedule::<45>(s, b); - round_schedule::<46>(s, b); - round_schedule::<47>(s, b); - round_schedule::<48>(s, b); - round_schedule::<49>(s, b); - round_schedule::<50>(s, b); - round_schedule::<51>(s, b); - round_schedule::<52>(s, b); - round_schedule::<53>(s, b); - round_schedule::<54>(s, b); - round_schedule::<55>(s, b); - round_schedule::<56>(s, b); - round_schedule::<57>(s, b); - round_schedule::<58>(s, b); - round_schedule::<59>(s, b); - round_schedule::<60>(s, b); - round_schedule::<61>(s, b); - round_schedule::<62>(s, b); - round_schedule::<63>(s, b); - round::<64>(s, b); - round::<65>(s, b); - round::<66>(s, b); - round::<67>(s, b); - round::<68>(s, b); - round::<69>(s, b); - round::<70>(s, b); - round::<71>(s, b); - round::<72>(s, b); - round::<73>(s, b); - round::<74>(s, b); - round::<75>(s, b); - round::<76>(s, b); - round::<77>(s, b); - round::<78>(s, b); - round::<79>(s, b); + for i in 0..4 { + let k = &K64[16 * i..]; + round_schedule::<0>(s, b, k); + round_schedule::<1>(s, b, k); + round_schedule::<2>(s, b, k); + round_schedule::<3>(s, b, k); + round_schedule::<4>(s, b, k); + round_schedule::<5>(s, b, k); + round_schedule::<6>(s, b, k); + round_schedule::<7>(s, b, k); + round_schedule::<8>(s, b, k); + round_schedule::<9>(s, b, k); + round_schedule::<10>(s, b, k); + round_schedule::<11>(s, b, k); + round_schedule::<12>(s, b, k); + round_schedule::<13>(s, b, k); + round_schedule::<14>(s, b, k); + round_schedule::<15>(s, b, k); + } + + let k = &K64[64..]; + round::<0>(s, b, k); + round::<1>(s, b, k); + round::<2>(s, b, k); + round::<3>(s, b, k); + round::<4>(s, b, k); + round::<5>(s, b, k); + round::<6>(s, b, k); + round::<7>(s, b, k); + round::<8>(s, b, k); + round::<9>(s, b, k); + round::<10>(s, b, k); + round::<11>(s, b, k); + round::<12>(s, b, k); + round::<13>(s, b, k); + round::<14>(s, b, k); + round::<15>(s, b, k); for i in 0..8 { state[i] = state[i].wrapping_add(s[i]); @@ -170,7 +162,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter().map(super::to_u64s) { + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha512/riscv_zknh_compact.rs b/sha2/src/sha512/riscv_zknh_compact.rs index 92e984c52..729157c45 100644 --- a/sha2/src/sha512/riscv_zknh_compact.rs +++ b/sha2/src/sha512/riscv_zknh_compact.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features"); #[cfg(target_arch = "riscv32")] unsafe fn sha512sum0(x: u64) -> u64 { @@ -71,9 +74,7 @@ fn round(state: &mut [u64; 8], block: &[u64; 16], r: usize) { } #[inline(always)] -fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], r: usize) { - round(state, block, r); - +fn schedule(block: &mut [u64; 16], r: usize) { block[r % 16] = block[r % 16] .wrapping_add(unsafe { sha512sig1(block[(r + 14) % 16]) }) .wrapping_add(block[(r + 9) % 16]) @@ -82,14 +83,13 @@ fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], r: usize) { #[inline(always)] fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { - let s = &mut state.clone(); - let b = &mut block; + let mut s = *state; - for i in 0..64 { - round_schedule(s, b, i); - } - for i in 64..80 { - round(s, b, i); + for r in 0..80 { + round(&mut s, &block, r); + if r < 64 { + schedule(&mut block, r) + } } for i in 0..8 { @@ -98,7 +98,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter().map(super::to_u64s) { + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha512/riscv_zknh_utils.rs b/sha2/src/sha512/riscv_zknh_utils.rs new file mode 100644 index 000000000..0b4746069 --- /dev/null +++ b/sha2/src/sha512/riscv_zknh_utils.rs @@ -0,0 +1,129 @@ +use core::{arch::asm, ptr}; + +#[inline(always)] +pub(super) fn load_block(block: &[u8; 128]) -> [u64; 16] { + if block.as_ptr().cast::().is_aligned() { + load_aligned_block(block) + } else { + load_unaligned_block(block) + } +} + +#[cfg(target_arch = "riscv32")] +fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { + let p: *const [u32; 32] = block.as_ptr().cast(); + debug_assert!(p.is_aligned()); + let block = unsafe { &*p }; + let mut res = [0u64; 16]; + for i in 0..16 { + let a = block[2 * i].to_be() as u64; + let b = block[2 * i + 1].to_be() as u64; + res[i] = (a << 32) | b; + } + res +} + +#[cfg(target_arch = "riscv64")] +fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { + let block_ptr: *const u64 = block.as_ptr().cast(); + debug_assert!(block_ptr.is_aligned()); + let mut res = [0u64; 16]; + for i in 0..16 { + let val = unsafe { ptr::read(block_ptr.add(i)) }; + res[i] = val.to_be(); + } + res +} + +#[cfg(target_arch = "riscv32")] +fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { + let offset = (block.as_ptr() as usize) % align_of::(); + debug_assert_ne!(offset, 0); + let off1 = (8 * offset) % 32; + let off2 = (32 - off1) % 32; + let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast(); + + let mut left: u32; + let mut block32 = [0u32; 32]; + + unsafe { + asm!( + "lw {left}, 0({bp})", // left = unsafe { ptr::read(bp) }; + "srl {left}, {left}, {off1}", // left >>= off1; + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + } + + for i in 0..31 { + let right = unsafe { ptr::read(bp.add(1 + i)) }; + block32[i] = left | (right << off2); + left = right >> off1; + } + + let right: u32; + unsafe { + asm!( + "lw {right}, 32 * 4({bp})", // right = ptr::read(bp.add(32)); + "sll {right}, {right}, {off2}", // right <<= off2; + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + } + block32[31] = left | right; + + let mut block64 = [0u64; 16]; + for i in 0..16 { + let a = block32[2 * i].to_be() as u64; + let b = block32[2 * i + 1].to_be() as u64; + block64[i] = (a << 32) | b; + } + block64 +} + +#[cfg(target_arch = "riscv64")] +fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { + let offset = (block.as_ptr() as usize) % align_of::(); + debug_assert_ne!(offset, 0); + let off1 = (8 * offset) % 64; + let off2 = (64 - off1) % 64; + let bp: *const u64 = block.as_ptr().wrapping_sub(offset).cast(); + + let mut left: u64; + let mut res = [0u64; 16]; + + unsafe { + asm!( + "ld {left}, 0({bp})", // left = unsafe { ptr::read(bp) }; + "srl {left}, {left}, {off1}", // left >>= off1; + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + } + for i in 0..15 { + let right = unsafe { ptr::read(bp.add(1 + i)) }; + res[i] = (left | (right << off2)).to_be(); + left = right >> off1; + } + + let right: u64; + unsafe { + asm!( + "ld {right}, 16 * 8({bp})", // right = ptr::read(bp.add(16)); + "sll {right}, {right}, {off2}", // right <<= off2; + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + } + res[15] = (left | right).to_be(); + + res +}