From b6f265f15737b0a6541169c17757004a3d8bfc8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 19:02:27 +0300 Subject: [PATCH 01/10] sha2: improve RISC-V Zknh backend --- sha2/src/lib.rs | 1 + sha2/src/sha256.rs | 2 + sha2/src/sha256/riscv_zknh.rs | 26 +++++- sha2/src/sha256/riscv_zknh_compact.rs | 28 +++--- sha2/src/sha256/riscv_zknh_utils.rs | 77 ++++++++++++++++ sha2/src/sha512.rs | 2 + sha2/src/sha512/riscv_zknh.rs | 26 +++++- sha2/src/sha512/riscv_zknh_compact.rs | 28 +++--- sha2/src/sha512/riscv_zknh_utils.rs | 124 ++++++++++++++++++++++++++ 9 files changed, 280 insertions(+), 34 deletions(-) create mode 100644 sha2/src/sha256/riscv_zknh_utils.rs create mode 100644 sha2/src/sha512/riscv_zknh_utils.rs diff --git a/sha2/src/lib.rs b/sha2/src/lib.rs index 1efc78b6f..9d9c9f448 100644 --- a/sha2/src/lib.rs +++ b/sha2/src/lib.rs @@ -10,6 +10,7 @@ any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"), feature(riscv_ext_intrinsics) )] +#![allow(clippy::needless_range_loop)] #[cfg(all( any(sha2_backend = "riscv-zknh", sha2_backend = "riscv-zknh-compact"), diff --git a/sha2/src/sha256.rs b/sha2/src/sha256.rs index 4e0849020..6d5896f13 100644 --- a/sha2/src/sha256.rs +++ b/sha2/src/sha256.rs @@ -11,12 +11,14 @@ cfg_if::cfg_if! 
{ sha2_backend = "riscv-zknh" ))] { mod riscv_zknh; + mod riscv_zknh_utils; use riscv_zknh::compress; } else if #[cfg(all( any(target_arch = "riscv32", target_arch = "riscv64"), sha2_backend = "riscv-zknh-compact" ))] { mod riscv_zknh_compact; + mod riscv_zknh_utils; use riscv_zknh_compact::compress; } else if #[cfg(target_arch = "aarch64")] { mod soft; diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index fe950bdc8..50768c6f2 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -5,8 +5,8 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] +compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { @@ -126,8 +126,26 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { } } -pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter().map(super::to_u32s) { +#[inline(never)] +fn compress_aligned(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + for block in blocks { + let block = super::riscv_zknh_utils::load_aligned_block(block); + compress_block(state, block); + } +} + +#[cold] +fn compress_unaligned(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + for block in blocks { + let block = super::riscv_zknh_utils::load_unaligned_block(block); compress_block(state, block); } } + +pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + if blocks.as_ptr().cast::().is_aligned() { + compress_aligned(state, blocks); + } else { + compress_unaligned(state, blocks); + } +} diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh_compact.rs index 98375cce7..cf66ab5c3 100644 --- a/sha2/src/sha256/riscv_zknh_compact.rs +++ 
b/sha2/src/sha256/riscv_zknh_compact.rs @@ -5,8 +5,8 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] +compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { @@ -43,9 +43,7 @@ fn round(state: &mut [u32; 8], block: &[u32; 16], r: usize) { } #[inline(always)] -fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) { - round(state, block, r); - +fn schedule(block: &mut [u32; 16], r: usize) { block[r % 16] = block[r % 16] .wrapping_add(unsafe { sha256sig1(block[(r + 14) % 16]) }) .wrapping_add(block[(r + 9) % 16]) @@ -54,14 +52,13 @@ fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], r: usize) { #[inline(always)] fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { - let s = &mut state.clone(); - let b = &mut block; + let mut s = *state; - for i in 0..48 { - round_schedule(s, b, i); - } - for i in 48..64 { - round(s, b, i); + for r in 0..64 { + round(&mut s, &block, r); + if r < 48 { + schedule(&mut block, r) + } } for i in 0..8 { @@ -70,7 +67,12 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { } pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter().map(super::to_u32s) { + for block in blocks.iter() { + let block = if block.as_ptr().cast::().is_aligned() { + super::riscv_zknh_utils::load_aligned_block(block) + } else { + super::riscv_zknh_utils::load_unaligned_block(block) + }; compress_block(state, block); } } diff --git a/sha2/src/sha256/riscv_zknh_utils.rs b/sha2/src/sha256/riscv_zknh_utils.rs new file mode 100644 index 000000000..903220156 --- /dev/null +++ b/sha2/src/sha256/riscv_zknh_utils.rs @@ -0,0 +1,77 @@ +use core::{arch::asm, ptr}; + 
+#[inline(always)] +pub(super) fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] { + let p: *const u32 = block.as_ptr().cast(); + debug_assert!(p.is_aligned()); + let mut res = [0u32; 16]; + for i in 0..16 { + let val = unsafe { ptr::read(p.add(i)) }; + res[i] = val.to_be(); + } + res +} + +#[inline(always)] +pub(super) fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { + let offset = (block.as_ptr() as usize) % align_of::(); + debug_assert_ne!(offset, 0); + let off1 = (8 * offset) % 32; + let off2 = (32 - off1) % 32; + let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast(); + + let mut left: u32; + let mut res = [0u32; 16]; + + unsafe { + #[cfg(target_arch = "riscv64")] + asm!( + "lwu {left}, 0({bp})", + "srl {left}, {left}, {off1}", + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + #[cfg(target_arch = "riscv32")] + asm!( + "lw {left}, 0({bp})", + "srl {left}, {left}, {off1}", + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + } + + for i in 0..15 { + let right = unsafe { ptr::read(bp.add(1 + i)) }; + res[i] = (left | (right << off2)).to_be(); + left = right >> off1; + } + + let right: u32; + unsafe { + #[cfg(target_arch = "riscv64")] + asm!( + "lwu {right}, 64({bp})", + "sll {right}, {right}, {off2}", + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + #[cfg(target_arch = "riscv32")] + asm!( + "lw {right}, 64({bp})", + "sll {right}, {right}, {off2}", + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + } + res[15] = (left | right).to_be(); + + res +} diff --git a/sha2/src/sha512.rs b/sha2/src/sha512.rs index 08e802fcc..206792665 100644 --- a/sha2/src/sha512.rs +++ b/sha2/src/sha512.rs @@ -11,12 +11,14 @@ cfg_if::cfg_if! 
{ sha2_backend = "riscv-zknh" ))] { mod riscv_zknh; + mod riscv_zknh_utils; use riscv_zknh::compress; } else if #[cfg(all( any(target_arch = "riscv32", target_arch = "riscv64"), sha2_backend = "riscv-zknh-compact" ))] { mod riscv_zknh_compact; + mod riscv_zknh_utils; use riscv_zknh_compact::compress; } else if #[cfg(target_arch = "aarch64")] { mod soft; diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index 31a327ebb..621e02648 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -5,8 +5,8 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] +compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); #[cfg(target_arch = "riscv32")] unsafe fn sha512sum0(x: u64) -> u64 { @@ -169,8 +169,26 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } } -pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter().map(super::to_u64s) { +#[inline(never)] +fn compress_aligned(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + for block in blocks { + let block = super::riscv_zknh_utils::load_aligned_block(block); + compress_block(state, block); + } +} + +#[cold] +fn compress_unaligned(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + for block in blocks { + let block = super::riscv_zknh_utils::load_unaligned_block(block); compress_block(state, block); } } + +pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { + if blocks.as_ptr().cast::().is_aligned() { + compress_aligned(state, blocks); + } else { + compress_unaligned(state, blocks); + } +} diff --git a/sha2/src/sha512/riscv_zknh_compact.rs b/sha2/src/sha512/riscv_zknh_compact.rs index 92e984c52..6ecb39875 100644 --- a/sha2/src/sha512/riscv_zknh_compact.rs +++ 
b/sha2/src/sha512/riscv_zknh_compact.rs @@ -5,8 +5,8 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(target_feature = "zknh"))] -compile_error!("riscv-zknh backend requires enabled zknh target feature"); +#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] +compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); #[cfg(target_arch = "riscv32")] unsafe fn sha512sum0(x: u64) -> u64 { @@ -71,9 +71,7 @@ fn round(state: &mut [u64; 8], block: &[u64; 16], r: usize) { } #[inline(always)] -fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], r: usize) { - round(state, block, r); - +fn schedule(block: &mut [u64; 16], r: usize) { block[r % 16] = block[r % 16] .wrapping_add(unsafe { sha512sig1(block[(r + 14) % 16]) }) .wrapping_add(block[(r + 9) % 16]) @@ -82,14 +80,13 @@ fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], r: usize) { #[inline(always)] fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { - let s = &mut state.clone(); - let b = &mut block; + let mut s = *state; - for i in 0..64 { - round_schedule(s, b, i); - } - for i in 64..80 { - round(s, b, i); + for r in 0..80 { + round(&mut s, &block, r); + if r < 64 { + schedule(&mut block, r) + } } for i in 0..8 { @@ -98,7 +95,12 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter().map(super::to_u64s) { + for block in blocks.iter() { + let block = if block.as_ptr().cast::().is_aligned() { + super::riscv_zknh_utils::load_aligned_block(block) + } else { + super::riscv_zknh_utils::load_unaligned_block(block) + }; compress_block(state, block); } } diff --git a/sha2/src/sha512/riscv_zknh_utils.rs b/sha2/src/sha512/riscv_zknh_utils.rs new file mode 100644 index 000000000..d1a7d7549 --- /dev/null +++ b/sha2/src/sha512/riscv_zknh_utils.rs @@ -0,0 +1,124 @@ +use core::{arch::asm, ptr}; + 
+#[cfg(target_arch = "riscv32")] +#[inline(always)] +pub(super) fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { + let p: *const [u32; 32] = block.as_ptr().cast(); + debug_assert!(p.is_aligned()); + let block = unsafe { &*p }; + let mut res = [0u64; 16]; + for i in 0..16 { + let a = block[2 * i].to_be() as u64; + let b = block[2 * i + 1].to_be() as u64; + res[i] = (a << 32) | b; + } + res +} + +#[cfg(target_arch = "riscv64")] +#[inline(always)] +pub(super) fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { + let block_ptr: *const u64 = block.as_ptr().cast(); + debug_assert!(block_ptr.is_aligned()); + let mut res = [0u64; 16]; + for i in 0..16 { + let val = unsafe { ptr::read(block_ptr.add(i)) }; + res[i] = val.to_be(); + } + res +} + +#[cfg(target_arch = "riscv32")] +#[inline(always)] +pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { + let offset = (block.as_ptr() as usize) % align_of::(); + debug_assert_ne!(offset, 0); + let off1 = (8 * offset) % 32; + let off2 = (32 - off1) % 32; + let bp: *const u32 = block.as_ptr().wrapping_sub(offset).cast(); + + let mut left: u32; + let mut block32 = [0u32; 32]; + + unsafe { + asm!( + "lw {left}, 0({bp})", + "srl {left}, {left}, {off1}", + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + } + + for i in 0..31 { + let right = unsafe { ptr::read(bp.add(1 + i)) }; + block32[i] = left | (right << off2); + left = right >> off1; + } + + let right: u32; + unsafe { + asm!( + "lw {right}, 128({bp})", + "sll {right}, {right}, {off2}", + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + } + block32[31] = left | right; + + let mut block64 = [0u64; 16]; + for i in 0..16 { + let a = block32[2 * i].to_be() as u64; + let b = block32[2 * i + 1].to_be() as u64; + block64[i] = (a << 32) | b; + } + block64 +} + +#[cfg(target_arch = "riscv64")] +#[inline(always)] 
+pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { + let offset = (block.as_ptr() as usize) % align_of::(); + debug_assert_ne!(offset, 0); + let off1 = (8 * offset) % 64; + let off2 = (64 - off1) % 64; + let bp: *const u64 = block.as_ptr().wrapping_sub(offset).cast(); + + let mut left: u64; + let mut res = [0u64; 16]; + + unsafe { + asm!( + "ld {left}, 0({bp})", + "srl {left}, {left}, {off1}", + bp = in(reg) bp, + off1 = in(reg) off1, + left = out(reg) left, + options(pure, nostack, readonly, preserves_flags), + ); + } + for i in 0..15 { + let right = unsafe { ptr::read(bp.add(1 + i)) }; + res[i] = (left | (right << off2)).to_be(); + left = right >> off1; + } + + let right: u64; + unsafe { + asm!( + "ld {right}, 128({bp})", + "sll {right}, {right}, {off2}", + bp = in(reg) bp, + off2 = in(reg) off2, + right = out(reg) right, + options(pure, nostack, readonly, preserves_flags), + ); + } + res[15] = (left | right).to_be(); + + res +} From 8fb3a994310ed70580e33fb37b5c96132151e261 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 19:27:40 +0300 Subject: [PATCH 02/10] Use less aggressive inlining in the riscv-zknh backend --- sha2/src/sha256/riscv_zknh.rs | 140 +++++++++++-------------------- sha2/src/sha512/riscv_zknh.rs | 154 +++++++++++----------------------- 2 files changed, 99 insertions(+), 195 deletions(-) diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index 50768c6f2..bc73faad0 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -18,8 +18,7 @@ fn maj(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (x & z) ^ (y & z) } -#[allow(clippy::identity_op)] -fn round(state: &mut [u32; 8], block: &[u32; 16]) { +fn round(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) { let n = K32.len() - R; #[allow(clippy::identity_op)] let a = (n + 0) % 8; @@ -35,117 +34,78 @@ fn round(state: 
&mut [u32; 8], block: &[u32; 16]) { .wrapping_add(unsafe { sha256sum1(state[e]) }) .wrapping_add(ch(state[e], state[f], state[g])) // Force reading of constants from the static to prevent bad codegen - .wrapping_add(unsafe { core::ptr::read_volatile(&K32[R]) }) - .wrapping_add(block[R % 16]); + .wrapping_add(unsafe { core::ptr::read_volatile(&k[R]) }) + .wrapping_add(block[R]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] .wrapping_add(unsafe { sha256sum0(state[a]) }) .wrapping_add(maj(state[a], state[b], state[c])) } -fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16]) { - round::(state, block); +fn round_schedule(state: &mut [u32; 8], block: &mut [u32; 16], k: &[u32]) { + round::(state, block, k); - block[R % 16] = block[R % 16] + block[R] = block[R] .wrapping_add(unsafe { sha256sig1(block[(R + 14) % 16]) }) .wrapping_add(block[(R + 9) % 16]) .wrapping_add(unsafe { sha256sig0(block[(R + 1) % 16]) }); } +#[inline(always)] fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { let s = &mut state.clone(); let b = &mut block; - round_schedule::<0>(s, b); - round_schedule::<1>(s, b); - round_schedule::<2>(s, b); - round_schedule::<3>(s, b); - round_schedule::<4>(s, b); - round_schedule::<5>(s, b); - round_schedule::<6>(s, b); - round_schedule::<7>(s, b); - round_schedule::<8>(s, b); - round_schedule::<9>(s, b); - round_schedule::<10>(s, b); - round_schedule::<11>(s, b); - round_schedule::<12>(s, b); - round_schedule::<13>(s, b); - round_schedule::<14>(s, b); - round_schedule::<15>(s, b); - round_schedule::<16>(s, b); - round_schedule::<17>(s, b); - round_schedule::<18>(s, b); - round_schedule::<19>(s, b); - round_schedule::<20>(s, b); - round_schedule::<21>(s, b); - round_schedule::<22>(s, b); - round_schedule::<23>(s, b); - round_schedule::<24>(s, b); - round_schedule::<25>(s, b); - round_schedule::<26>(s, b); - round_schedule::<27>(s, b); - round_schedule::<28>(s, b); - round_schedule::<29>(s, b); - round_schedule::<30>(s, b); 
- round_schedule::<31>(s, b); - round_schedule::<32>(s, b); - round_schedule::<33>(s, b); - round_schedule::<34>(s, b); - round_schedule::<35>(s, b); - round_schedule::<36>(s, b); - round_schedule::<37>(s, b); - round_schedule::<38>(s, b); - round_schedule::<39>(s, b); - round_schedule::<40>(s, b); - round_schedule::<41>(s, b); - round_schedule::<42>(s, b); - round_schedule::<43>(s, b); - round_schedule::<44>(s, b); - round_schedule::<45>(s, b); - round_schedule::<46>(s, b); - round_schedule::<47>(s, b); - round::<48>(s, b); - round::<49>(s, b); - round::<50>(s, b); - round::<51>(s, b); - round::<52>(s, b); - round::<53>(s, b); - round::<54>(s, b); - round::<55>(s, b); - round::<56>(s, b); - round::<57>(s, b); - round::<58>(s, b); - round::<59>(s, b); - round::<60>(s, b); - round::<61>(s, b); - round::<62>(s, b); - round::<63>(s, b); + for i in 0..3 { + let k = &K32[16 * i..]; + round_schedule::<0>(s, b, k); + round_schedule::<1>(s, b, k); + round_schedule::<2>(s, b, k); + round_schedule::<3>(s, b, k); + round_schedule::<4>(s, b, k); + round_schedule::<5>(s, b, k); + round_schedule::<6>(s, b, k); + round_schedule::<7>(s, b, k); + round_schedule::<8>(s, b, k); + round_schedule::<9>(s, b, k); + round_schedule::<10>(s, b, k); + round_schedule::<11>(s, b, k); + round_schedule::<12>(s, b, k); + round_schedule::<13>(s, b, k); + round_schedule::<14>(s, b, k); + round_schedule::<15>(s, b, k); + } + + let k = &K32[48..]; + round::<0>(s, b, k); + round::<1>(s, b, k); + round::<2>(s, b, k); + round::<3>(s, b, k); + round::<4>(s, b, k); + round::<5>(s, b, k); + round::<6>(s, b, k); + round::<7>(s, b, k); + round::<8>(s, b, k); + round::<9>(s, b, k); + round::<10>(s, b, k); + round::<11>(s, b, k); + round::<12>(s, b, k); + round::<13>(s, b, k); + round::<14>(s, b, k); + round::<15>(s, b, k); for i in 0..8 { state[i] = state[i].wrapping_add(s[i]); } } -#[inline(never)] -fn compress_aligned(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks { - let block = 
super::riscv_zknh_utils::load_aligned_block(block); - compress_block(state, block); - } -} - -#[cold] -fn compress_unaligned(state: &mut [u32; 8], blocks: &[[u8; 64]]) { +pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { for block in blocks { - let block = super::riscv_zknh_utils::load_unaligned_block(block); + let block = if block.as_ptr().cast::().is_aligned() { + super::riscv_zknh_utils::load_aligned_block(block) + } else { + super::riscv_zknh_utils::load_unaligned_block(block) + }; compress_block(state, block); } } - -pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - if blocks.as_ptr().cast::().is_aligned() { - compress_aligned(state, blocks); - } else { - compress_unaligned(state, blocks); - } -} diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index 621e02648..26aa2b59c 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -46,7 +46,7 @@ fn maj(x: u64, y: u64, z: u64) -> u64 { (x & y) ^ (x & z) ^ (y & z) } -fn round(state: &mut [u64; 8], block: &[u64; 16]) { +fn round(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { let n = K64.len() - R; #[allow(clippy::identity_op)] let a = (n + 0) % 8; @@ -62,18 +62,18 @@ fn round(state: &mut [u64; 8], block: &[u64; 16]) { .wrapping_add(unsafe { sha512sum1(state[e]) }) .wrapping_add(ch(state[e], state[f], state[g])) // Force reading of constants from the static to prevent bad codegen - .wrapping_add(unsafe { core::ptr::read_volatile(&K64[R]) }) - .wrapping_add(block[R % 16]); + .wrapping_add(unsafe { core::ptr::read_volatile(&k[R]) }) + .wrapping_add(block[R]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] .wrapping_add(unsafe { sha512sum0(state[a]) }) .wrapping_add(maj(state[a], state[b], state[c])) } -fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16]) { - round::(state, block); +fn round_schedule(state: &mut [u64; 8], block: &mut [u64; 16], k: &[u64]) { + round::(state, block, k); - block[R % 16] = block[R % 
16] + block[R] = block[R] .wrapping_add(unsafe { sha512sig1(block[(R + 14) % 16]) }) .wrapping_add(block[(R + 9) % 16]) .wrapping_add(unsafe { sha512sig0(block[(R + 1) % 16]) }); @@ -83,112 +83,56 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { let s = &mut state.clone(); let b = &mut block; - round_schedule::<0>(s, b); - round_schedule::<1>(s, b); - round_schedule::<2>(s, b); - round_schedule::<3>(s, b); - round_schedule::<4>(s, b); - round_schedule::<5>(s, b); - round_schedule::<6>(s, b); - round_schedule::<7>(s, b); - round_schedule::<8>(s, b); - round_schedule::<9>(s, b); - round_schedule::<10>(s, b); - round_schedule::<11>(s, b); - round_schedule::<12>(s, b); - round_schedule::<13>(s, b); - round_schedule::<14>(s, b); - round_schedule::<15>(s, b); - round_schedule::<16>(s, b); - round_schedule::<17>(s, b); - round_schedule::<18>(s, b); - round_schedule::<19>(s, b); - round_schedule::<20>(s, b); - round_schedule::<21>(s, b); - round_schedule::<22>(s, b); - round_schedule::<23>(s, b); - round_schedule::<24>(s, b); - round_schedule::<25>(s, b); - round_schedule::<26>(s, b); - round_schedule::<27>(s, b); - round_schedule::<28>(s, b); - round_schedule::<29>(s, b); - round_schedule::<30>(s, b); - round_schedule::<31>(s, b); - round_schedule::<32>(s, b); - round_schedule::<33>(s, b); - round_schedule::<34>(s, b); - round_schedule::<35>(s, b); - round_schedule::<36>(s, b); - round_schedule::<37>(s, b); - round_schedule::<38>(s, b); - round_schedule::<39>(s, b); - round_schedule::<40>(s, b); - round_schedule::<41>(s, b); - round_schedule::<42>(s, b); - round_schedule::<43>(s, b); - round_schedule::<44>(s, b); - round_schedule::<45>(s, b); - round_schedule::<46>(s, b); - round_schedule::<47>(s, b); - round_schedule::<48>(s, b); - round_schedule::<49>(s, b); - round_schedule::<50>(s, b); - round_schedule::<51>(s, b); - round_schedule::<52>(s, b); - round_schedule::<53>(s, b); - round_schedule::<54>(s, b); - round_schedule::<55>(s, b); - 
round_schedule::<56>(s, b); - round_schedule::<57>(s, b); - round_schedule::<58>(s, b); - round_schedule::<59>(s, b); - round_schedule::<60>(s, b); - round_schedule::<61>(s, b); - round_schedule::<62>(s, b); - round_schedule::<63>(s, b); - round::<64>(s, b); - round::<65>(s, b); - round::<66>(s, b); - round::<67>(s, b); - round::<68>(s, b); - round::<69>(s, b); - round::<70>(s, b); - round::<71>(s, b); - round::<72>(s, b); - round::<73>(s, b); - round::<74>(s, b); - round::<75>(s, b); - round::<76>(s, b); - round::<77>(s, b); - round::<78>(s, b); - round::<79>(s, b); + for i in 0..4 { + let k = &K64[16 * i..]; + round_schedule::<0>(s, b, k); + round_schedule::<1>(s, b, k); + round_schedule::<2>(s, b, k); + round_schedule::<3>(s, b, k); + round_schedule::<4>(s, b, k); + round_schedule::<5>(s, b, k); + round_schedule::<6>(s, b, k); + round_schedule::<7>(s, b, k); + round_schedule::<8>(s, b, k); + round_schedule::<9>(s, b, k); + round_schedule::<10>(s, b, k); + round_schedule::<11>(s, b, k); + round_schedule::<12>(s, b, k); + round_schedule::<13>(s, b, k); + round_schedule::<14>(s, b, k); + round_schedule::<15>(s, b, k); + } + + let k = &K64[64..]; + round::<0>(s, b, k); + round::<1>(s, b, k); + round::<2>(s, b, k); + round::<3>(s, b, k); + round::<4>(s, b, k); + round::<5>(s, b, k); + round::<6>(s, b, k); + round::<7>(s, b, k); + round::<8>(s, b, k); + round::<9>(s, b, k); + round::<10>(s, b, k); + round::<11>(s, b, k); + round::<12>(s, b, k); + round::<13>(s, b, k); + round::<14>(s, b, k); + round::<15>(s, b, k); for i in 0..8 { state[i] = state[i].wrapping_add(s[i]); } } -#[inline(never)] -fn compress_aligned(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks { - let block = super::riscv_zknh_utils::load_aligned_block(block); - compress_block(state, block); - } -} - -#[cold] -fn compress_unaligned(state: &mut [u64; 8], blocks: &[[u8; 128]]) { +pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { for block in blocks { - let block = 
super::riscv_zknh_utils::load_unaligned_block(block); + let block = if block.as_ptr().cast::().is_aligned() { + super::riscv_zknh_utils::load_aligned_block(block) + } else { + super::riscv_zknh_utils::load_unaligned_block(block) + }; compress_block(state, block); } } - -pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - if blocks.as_ptr().cast::().is_aligned() { - compress_aligned(state, blocks); - } else { - compress_unaligned(state, blocks); - } -} From 18db6234c1d72808ee5a616271622444b7b19b5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 19:30:39 +0300 Subject: [PATCH 03/10] relax target feature requirements --- sha2/src/sha256/riscv_zknh.rs | 7 +++++-- sha2/src/sha256/riscv_zknh_compact.rs | 7 +++++-- sha2/src/sha512/riscv_zknh.rs | 7 +++++-- sha2/src/sha512/riscv_zknh_compact.rs | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index bc73faad0..5b210f426 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] -compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh_compact.rs index cf66ab5c3..bfc6157e5 100644 --- a/sha2/src/sha256/riscv_zknh_compact.rs +++ b/sha2/src/sha256/riscv_zknh_compact.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use 
core::arch::riscv64::*; -#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] -compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index 26aa2b59c..81e66dad1 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] -compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); #[cfg(target_arch = "riscv32")] unsafe fn sha512sum0(x: u64) -> u64 { diff --git a/sha2/src/sha512/riscv_zknh_compact.rs b/sha2/src/sha512/riscv_zknh_compact.rs index 6ecb39875..170a5d593 100644 --- a/sha2/src/sha512/riscv_zknh_compact.rs +++ b/sha2/src/sha512/riscv_zknh_compact.rs @@ -5,8 +5,11 @@ use core::arch::riscv32::*; #[cfg(target_arch = "riscv64")] use core::arch::riscv64::*; -#[cfg(not(all(target_feature = "zknh", target_feature = "zbkb")))] -compile_error!("riscv-zknh-compact backend requires enabled zknh and zbkb target features"); +#[cfg(not(all( + target_feature = "zknh", + any(target_feature = "zbb", target_feature = "zbkb") +)))] +compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features"); #[cfg(target_arch = "riscv32")] unsafe fn sha512sum0(x: u64) -> u64 { From 0acd099102ced9c17051fa533add42abf55bdc68 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 19:54:12 +0300 Subject: [PATCH 04/10] Fix compile error cfg --- sha2/src/sha256/riscv_zknh_compact.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh_compact.rs index bfc6157e5..c26b924b8 100644 --- a/sha2/src/sha256/riscv_zknh_compact.rs +++ b/sha2/src/sha256/riscv_zknh_compact.rs @@ -9,7 +9,7 @@ use core::arch::riscv64::*; target_feature = "zknh", any(target_feature = "zbb", target_feature = "zbkb") )))] -compile_error!("riscv-zknh backend requires zknh and zbkb (or zbb) target features"); +compile_error!("riscv-zknh-compact backend requires zknh and zbkb (or zbb) target features"); #[inline(always)] fn ch(x: u32, y: u32, z: u32) -> u32 { From 2326f77b27b32e6c58bd0afa047be9b759b2aa2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 19:55:09 +0300 Subject: [PATCH 05/10] fix CI --- .github/workflows/sha2.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/sha2.yml b/.github/workflows/sha2.yml index 3efb039ef..fdb9e9a00 100644 --- a/.github/workflows/sha2.yml +++ b/.github/workflows/sha2.yml @@ -169,13 +169,13 @@ jobs: - run: cargo install cross --git https://github.com/cross-rs/cross - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C 
target-feature=+zknh,+zbkb - run: cross test --package sha2 --all-features --target riscv64gc-unknown-linux-gnu env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb riscv32-zknh: runs-on: ubuntu-latest @@ -188,13 +188,13 @@ jobs: components: rust-src - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="soft" -C target-feature=+zknh,+zbkb - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh" -C target-feature=+zknh,+zbkb - run: cargo build --all-features --target riscv32gc-unknown-linux-gnu -Z build-std env: - RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh + RUSTFLAGS: -Dwarnings --cfg sha2_backend="riscv-zknh-compact" -C target-feature=+zknh,+zbkb minimal-versions: uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master From 5ad4f5859445f180c616b02c6220fd4767d7de03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 20:07:04 +0300 Subject: [PATCH 06/10] Use asm!-based opaque load instead of volatile read --- sha2/src/sha256/riscv_zknh.rs | 19 +++++++++++++++++-- sha2/src/sha512/riscv_zknh.rs | 19 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index 5b210f426..ea06d9481 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -21,6 +21,22 @@ fn maj(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (x & z) ^ (y & z) } +/// This 
function returns `k[R]`, but prevents compiler from inlining the indexed value +pub(super) fn opaque_load<const R: usize>(k: &[u32]) -> u32 { + assert!(R < k.len()); + let dst; + unsafe { + core::arch::asm!( + "lw {dst}, 4*{R}({k})", + R = const R, + k = in(reg) k.as_ptr(), + dst = out(reg) dst, + options(pure, readonly, nostack, preserves_flags), + ); + } + dst +} + fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) { let n = K32.len() - R; #[allow(clippy::identity_op)] @@ -36,8 +52,7 @@ fn round<const R: usize>(state: &mut [u32; 8], block: &[u32; 16], k: &[u32]) { state[h] = state[h] .wrapping_add(unsafe { sha256sum1(state[e]) }) .wrapping_add(ch(state[e], state[f], state[g])) - // Force reading of constants from the static to prevent bad codegen - .wrapping_add(unsafe { core::ptr::read_volatile(&k[R]) }) + .wrapping_add(opaque_load::<R>(k)) .wrapping_add(block[R]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index 81e66dad1..e8e0a1319 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -49,6 +49,22 @@ fn maj(x: u64, y: u64, z: u64) -> u64 { (x & y) ^ (x & z) ^ (y & z) } +/// This function returns `k[R]`, but prevents compiler from inlining the indexed value +pub(super) fn opaque_load<const R: usize>(k: &[u64]) -> u64 { + assert!(R < k.len()); + let dst; + unsafe { + core::arch::asm!( + "ld {dst}, 8*{R}({k})", + R = const R, + k = in(reg) k.as_ptr(), + dst = out(reg) dst, + options(pure, readonly, nostack, preserves_flags), + ); + } + dst +} + fn round<const R: usize>(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { let n = K64.len() - R; #[allow(clippy::identity_op)] @@ -64,8 +80,7 @@ fn round<const R: usize>(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { state[h] = state[h] .wrapping_add(unsafe { sha512sum1(state[e]) }) .wrapping_add(ch(state[e], state[f], state[g])) - // Force reading of constants from the static to prevent bad codegen - .wrapping_add(unsafe { core::ptr::read_volatile(&k[R]) }) +
.wrapping_add(opaque_load::<R>(k)) .wrapping_add(block[R]); state[d] = state[d].wrapping_add(state[h]); state[h] = state[h] From 87c8b8779a2df01b4f5ee73ff486bde86b0952ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 20:45:43 +0300 Subject: [PATCH 07/10] fix opaque load --- sha2/src/sha256/riscv_zknh.rs | 11 +++++++++++ sha2/src/sha512/riscv_zknh.rs | 27 ++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index ea06d9481..17b058a2d 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -25,6 +25,17 @@ fn maj(x: u32, y: u32, z: u32) -> u32 { pub(super) fn opaque_load<const R: usize>(k: &[u32]) -> u32 { assert!(R < k.len()); let dst; + #[cfg(target_arch = "riscv64")] + unsafe { + core::arch::asm!( + "lwu {dst}, 4*{R}({k})", + R = const R, + k = in(reg) k.as_ptr(), + dst = out(reg) dst, + options(pure, readonly, nostack, preserves_flags), + ); + } + #[cfg(target_arch = "riscv32")] unsafe { core::arch::asm!( "lw {dst}, 4*{R}({k})", diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index e8e0a1319..cd6c5a71a 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -51,18 +51,35 @@ fn maj(x: u64, y: u64, z: u64) -> u64 { /// This function returns `k[R]`, but prevents compiler from inlining the indexed value pub(super) fn opaque_load<const R: usize>(k: &[u64]) -> u64 { + use core::arch::asm; assert!(R < k.len()); - let dst; + #[cfg(target_arch = "riscv64")] unsafe { - core::arch::asm!( - "ld {dst}, 8*{R}({k})", - R = const R, + let dst; + asm!( + "ld {dst}, {N}({k})", + N = const 8 * R, k = in(reg) k.as_ptr(), dst = out(reg) dst, options(pure, readonly, nostack, preserves_flags), ); + dst + } + #[cfg(target_arch = "riscv32")] + unsafe { + let [hi, lo]: [u32; 2]; + asm!( + "lwu {lo}, {N1}({k})", + "lwu {hi},
{N2}({k})", + N1 = const 8 * R, + N2 = const 8 * R + 4, + k = in(reg) k.as_ptr(), + lo = out(reg) lo, + hi = out(reg) hi, + options(pure, readonly, nostack, preserves_flags), + ); + ((hi as u64) << 32) | (lo as u64) } - dst } fn round<const R: usize>(state: &mut [u64; 8], block: &[u64; 16], k: &[u64]) { From 8386abc8cae255f0f0e7c46fe783f9fd6fc72a8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Mon, 26 Aug 2024 20:57:11 +0300 Subject: [PATCH 08/10] fix opaque load --- sha2/src/sha512/riscv_zknh.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index cd6c5a71a..675bed2c9 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -69,8 +69,8 @@ pub(super) fn opaque_load<const R: usize>(k: &[u64]) -> u64 { unsafe { let [hi, lo]: [u32; 2]; asm!( - "lwu {lo}, {N1}({k})", - "lwu {hi}, {N2}({k})", + "lw {lo}, {N1}({k})", + "lw {hi}, {N2}({k})", N1 = const 8 * R, N2 = const 8 * R + 4, k = in(reg) k.as_ptr(), From 1dec071608858ababa56d1042f248e3699de1141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Tue, 27 Aug 2024 12:28:37 +0300 Subject: [PATCH 09/10] Expose only `load_block` from the util modules --- sha2/src/sha256/riscv_zknh.rs | 7 +--- sha2/src/sha256/riscv_zknh_compact.rs | 7 +--- sha2/src/sha256/riscv_zknh_utils.rs | 55 ++++++++++++++------------- sha2/src/sha512/riscv_zknh.rs | 7 +--- sha2/src/sha512/riscv_zknh_compact.rs | 7 +--- sha2/src/sha512/riscv_zknh_utils.rs | 37 ++++++++++-------- 6 files changed, 54 insertions(+), 66 deletions(-) diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index 17b058a2d..5cef6b4b7 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -129,12 +129,7 @@ fn compress_block(state: &mut [u32;
8], mut block: [u32; 16]) { } pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks { - let block = if block.as_ptr().cast::<u32>().is_aligned() { - super::riscv_zknh_utils::load_aligned_block(block) - } else { - super::riscv_zknh_utils::load_unaligned_block(block) - }; + for block in blocks.map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha256/riscv_zknh_compact.rs b/sha2/src/sha256/riscv_zknh_compact.rs index c26b924b8..bba510a30 100644 --- a/sha2/src/sha256/riscv_zknh_compact.rs +++ b/sha2/src/sha256/riscv_zknh_compact.rs @@ -70,12 +70,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { } pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.iter() { - let block = if block.as_ptr().cast::<u32>().is_aligned() { - super::riscv_zknh_utils::load_aligned_block(block) - } else { - super::riscv_zknh_utils::load_unaligned_block(block) - }; + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha256/riscv_zknh_utils.rs b/sha2/src/sha256/riscv_zknh_utils.rs index 903220156..d75a0b1c1 100644 --- a/sha2/src/sha256/riscv_zknh_utils.rs +++ b/sha2/src/sha256/riscv_zknh_utils.rs @@ -1,7 +1,16 @@ use core::{arch::asm, ptr}; #[inline(always)] -pub(super) fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] { +pub(super) fn load_block(block: &[u8; 64]) -> [u32; 16] { + if block.as_ptr().cast::<u32>().is_aligned() { + load_aligned_block(block) + } else { + load_unaligned_block(block) + } +} + +#[inline(always)] +fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] { let p: *const u32 = block.as_ptr().cast(); debug_assert!(p.is_aligned()); let mut res = [0u32; 16]; @@ -13,7 +22,7 @@ pub(super) fn load_aligned_block(block: &[u8; 64]) -> [u32; 16] { } #[inline(always)] -pub(super) fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { +fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { let
offset = (block.as_ptr() as usize) % align_of::<u32>(); debug_assert_ne!(offset, 0); let off1 = (8 * offset) % 32; @@ -23,20 +32,24 @@ pub(super) fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { let mut left: u32; let mut res = [0u32; 16]; + /// Use LW instruction on RV32 and LWU on RV64 + #[cfg(target_arch = "riscv32")] + macro_rules! lw { + ($r:literal) => { + concat!("lw ", $r) + }; + } + #[cfg(target_arch = "riscv64")] + macro_rules! lw { + ($r:literal) => { + concat!("lwu ", $r) + }; + } + unsafe { - #[cfg(target_arch = "riscv64")] asm!( - "lwu {left}, 0({bp})", - "srl {left}, {left}, {off1}", - bp = in(reg) bp, - off1 = in(reg) off1, - left = out(reg) left, - options(pure, nostack, readonly, preserves_flags), - ); - #[cfg(target_arch = "riscv32")] - asm!( - "lw {left}, 0({bp})", - "srl {left}, {left}, {off1}", + lw!("{left}, 0({bp})"), // left = unsafe { ptr::read(bp) }; + "srl {left}, {left}, {off1}", // left >>= off1; bp = in(reg) bp, off1 = in(reg) off1, left = out(reg) left, @@ -52,19 +65,9 @@ pub(super) fn load_unaligned_block(block: &[u8; 64]) -> [u32; 16] { let right: u32; unsafe { - #[cfg(target_arch = "riscv64")] - asm!( - "lwu {right}, 64({bp})", - "sll {right}, {right}, {off2}", - bp = in(reg) bp, - off2 = in(reg) off2, - right = out(reg) right, - options(pure, nostack, readonly, preserves_flags), - ); - #[cfg(target_arch = "riscv32")] asm!( - "lw {right}, 64({bp})", - "sll {right}, {right}, {off2}", + lw!("{right}, 16 * 4({bp})"), // right = ptr::read(bp.add(16)); + "sll {right}, {right}, {off2}", // right <<= off2; bp = in(reg) bp, off2 = in(reg) off2, right = out(reg) right, diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index 675bed2c9..cf44eea14 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -162,12 +162,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks { - let block = if
block.as_ptr().cast::<usize>().is_aligned() { - super::riscv_zknh_utils::load_aligned_block(block) - } else { - super::riscv_zknh_utils::load_unaligned_block(block) - }; + for block in blocks.map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha512/riscv_zknh_compact.rs b/sha2/src/sha512/riscv_zknh_compact.rs index 170a5d593..729157c45 100644 --- a/sha2/src/sha512/riscv_zknh_compact.rs +++ b/sha2/src/sha512/riscv_zknh_compact.rs @@ -98,12 +98,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.iter() { - let block = if block.as_ptr().cast::<usize>().is_aligned() { - super::riscv_zknh_utils::load_aligned_block(block) - } else { - super::riscv_zknh_utils::load_unaligned_block(block) - }; + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha512/riscv_zknh_utils.rs b/sha2/src/sha512/riscv_zknh_utils.rs index d1a7d7549..0b4746069 100644 --- a/sha2/src/sha512/riscv_zknh_utils.rs +++ b/sha2/src/sha512/riscv_zknh_utils.rs @@ -1,8 +1,16 @@ use core::{arch::asm, ptr}; -#[cfg(target_arch = "riscv32")] #[inline(always)] -pub(super) fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { +pub(super) fn load_block(block: &[u8; 128]) -> [u64; 16] { + if block.as_ptr().cast::<usize>().is_aligned() { + load_aligned_block(block) + } else { + load_unaligned_block(block) + } +} + +#[cfg(target_arch = "riscv32")] +fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { let p: *const [u32; 32] = block.as_ptr().cast(); debug_assert!(p.is_aligned()); let block = unsafe { &*p }; @@ -16,8 +24,7 @@ pub(super) fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { } #[cfg(target_arch = "riscv64")] -#[inline(always)] -pub(super) fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { +fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { let block_ptr: *const u64 =
block.as_ptr().cast(); debug_assert!(block_ptr.is_aligned()); let mut res = [0u64; 16]; @@ -29,8 +36,7 @@ pub(super) fn load_aligned_block(block: &[u8; 128]) -> [u64; 16] { } #[cfg(target_arch = "riscv32")] -#[inline(always)] -pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { +fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { let offset = (block.as_ptr() as usize) % align_of::<u32>(); debug_assert_ne!(offset, 0); let off1 = (8 * offset) % 32; @@ -42,8 +48,8 @@ pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { unsafe { asm!( - "lw {left}, 0({bp})", - "srl {left}, {left}, {off1}", + "lw {left}, 0({bp})", // left = unsafe { ptr::read(bp) }; + "srl {left}, {left}, {off1}", // left >>= off1; bp = in(reg) bp, off1 = in(reg) off1, left = out(reg) left, @@ -60,8 +66,8 @@ pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { let right: u32; unsafe { asm!( - "lw {right}, 128({bp})", - "sll {right}, {right}, {off2}", + "lw {right}, 32 * 4({bp})", // right = ptr::read(bp.add(32)); + "sll {right}, {right}, {off2}", // right <<= off2; bp = in(reg) bp, off2 = in(reg) off2, right = out(reg) right, @@ -80,8 +86,7 @@ pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { } #[cfg(target_arch = "riscv64")] -#[inline(always)] -pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { +fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { let offset = (block.as_ptr() as usize) % align_of::<u64>(); debug_assert_ne!(offset, 0); let off1 = (8 * offset) % 64; @@ -93,8 +98,8 @@ pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { unsafe { asm!( - "ld {left}, 0({bp})", - "srl {left}, {left}, {off1}", + "ld {left}, 0({bp})", // left = unsafe { ptr::read(bp) }; + "srl {left}, {left}, {off1}", // left >>= off1; bp = in(reg) bp, off1 = in(reg) off1, left = out(reg) left, @@ -110,8 +115,8 @@ pub(super) fn load_unaligned_block(block: &[u8; 128]) -> [u64; 16] { let right: u64; unsafe { asm!( - "ld
{right}, 128({bp})", - "sll {right}, {right}, {off2}", + "ld {right}, 16 * 8({bp})", // right = ptr::read(bp.add(16)); + "sll {right}, {right}, {off2}", // right <<= off2; bp = in(reg) bp, off2 = in(reg) off2, right = out(reg) right, From 88bae593de260053edd6ca1c52184dc859432699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Tue, 27 Aug 2024 12:32:47 +0300 Subject: [PATCH 10/10] fix --- sha2/src/sha256/riscv_zknh.rs | 2 +- sha2/src/sha512/riscv_zknh.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sha2/src/sha256/riscv_zknh.rs b/sha2/src/sha256/riscv_zknh.rs index 5cef6b4b7..7477c6409 100644 --- a/sha2/src/sha256/riscv_zknh.rs +++ b/sha2/src/sha256/riscv_zknh.rs @@ -129,7 +129,7 @@ fn compress_block(state: &mut [u32; 8], mut block: [u32; 16]) { } pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { - for block in blocks.map(super::riscv_zknh_utils::load_block) { + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } } diff --git a/sha2/src/sha512/riscv_zknh.rs b/sha2/src/sha512/riscv_zknh.rs index cf44eea14..7be35ee84 100644 --- a/sha2/src/sha512/riscv_zknh.rs +++ b/sha2/src/sha512/riscv_zknh.rs @@ -162,7 +162,7 @@ fn compress_block(state: &mut [u64; 8], mut block: [u64; 16]) { } pub fn compress(state: &mut [u64; 8], blocks: &[[u8; 128]]) { - for block in blocks.map(super::riscv_zknh_utils::load_block) { + for block in blocks.iter().map(super::riscv_zknh_utils::load_block) { compress_block(state, block); } }