From 017d9c1a862ce5da80cda1d5c41d1ac1e1aab178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Sun, 10 May 2026 15:24:48 +0200 Subject: [PATCH 1/3] =?UTF-8?q?chore:=20add=20tests=20using=20khal?= =?UTF-8?q?=E2=80=99s=20new=20metal=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 1 + src/linalg/contiguous.rs | 8 ++++++++ src/linalg/gemm.rs | 8 ++++++++ src/linalg/op_assign.rs | 8 ++++++++ src/linalg/reduce.rs | 8 ++++++++ 5 files changed, 33 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index c888def..6eebd8f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ webgpu = [ "khal/webgpu" ] cpu = ["khal/cpu", "vortx-shaders/cpu"] cpu-parallel = ["cpu", "vortx-shaders/cpu-parallel"] cuda = ["khal/cuda", "khal-builder/cuda", "vortx-shaders/cuda"] +metal = ["khal/metal"] push_constants = ["khal/push_constants", "vortx-shaders/push_constants"] subgroup_ops = ["khal/subgroup_ops", "vortx-shaders/subgroup_ops"] diff --git a/src/linalg/contiguous.rs b/src/linalg/contiguous.rs index dc7cb3d..8c5289f 100644 --- a/src/linalg/contiguous.rs +++ b/src/linalg/contiguous.rs @@ -151,6 +151,14 @@ mod test { gpu_contiguous_generic(&cuda).await; } + #[cfg(feature = "metal")] + #[futures_test::test] + #[serial_test::serial] + async fn gpu_contiguous_metal() { + let metal = GpuBackend::Metal(khal::backend::metal::Metal::new().unwrap()); + gpu_contiguous_generic(&metal).await; + } + async fn gpu_contiguous_generic(backend: &GpuBackend) { let contiguous = super::Contiguous::from_backend(backend).unwrap(); diff --git a/src/linalg/gemm.rs b/src/linalg/gemm.rs index e68e7ce..cd43867 100644 --- a/src/linalg/gemm.rs +++ b/src/linalg/gemm.rs @@ -354,6 +354,14 @@ mod test { gpu_gemm_generic(&cuda).await; } + #[cfg(feature = "metal")] + #[futures_test::test] + #[serial_test::serial] + async fn gpu_gemm_metal() { + let metal = GpuBackend::Metal(khal::backend::metal::Metal::new().unwrap()); + gpu_gemm_generic(&metal).await; + } + async fn gpu_gemm_generic(backend: &GpuBackend) { let gemm = super::Gemm::from_backend(backend).unwrap(); diff --git a/src/linalg/op_assign.rs b/src/linalg/op_assign.rs index cb2f6fb..4f45e48 100644 --- a/src/linalg/op_assign.rs +++ b/src/linalg/op_assign.rs @@ -221,6 +221,14 @@ mod test { gpu_op_assign_with_backend(&cuda).await; } + #[cfg(feature = "metal")] + #[futures_test::test] + #[serial_test::serial] + async fn gpu_op_assign_metal() { + let metal = GpuBackend::Metal(khal::backend::metal::Metal::new().unwrap()); + gpu_op_assign_with_backend(&metal).await; + } + async fn gpu_op_assign_with_backend(backend: &GpuBackend) { let ops = [ OpAssignVariant::Add, diff --git a/src/linalg/reduce.rs b/src/linalg/reduce.rs index bfad8c5..0da7e36 100644 --- a/src/linalg/reduce.rs +++ b/src/linalg/reduce.rs @@ -218,6 +218,14 @@ mod test { gpu_reduce_generic(&cuda).await; } + #[cfg(feature = "metal")] + #[futures_test::test] + #[serial_test::serial] + async fn gpu_reduce_metal() { + let metal = GpuBackend::Metal(khal::backend::metal::Metal::new().unwrap()); + gpu_reduce_generic(&metal).await; + } + async fn gpu_reduce_generic(backend: &GpuBackend) { let ops = [ ReduceVariant::Min, From 686f80dc9be0cf90f651c4a988dedc4c86d2bc93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Fri, 15 May 2026 10:57:12 +0200 Subject: [PATCH 2/3] feat: use khal new helpers target_arch_is_gpu and setup_shader_crate_build --- vortx-shaders/Cargo.toml | 5 ++++- vortx-shaders/build.rs | 7 +------ vortx-shaders/src/lib.rs | 2 +- vortx-shaders/src/linalg/mod.rs | 10 +++++----- vortx-shaders/src/linalg/op_assign.rs | 2 +- vortx-shaders/src/linalg/shape.rs | 8 ++++---- vortx-shaders/src/utils/trig.rs | 2 +- 7 files changed, 17 insertions(+), 19 deletions(-) diff --git a/vortx-shaders/Cargo.toml b/vortx-shaders/Cargo.toml index 998d75f..72fc070 100644 --- a/vortx-shaders/Cargo.toml +++ b/vortx-shaders/Cargo.toml @@ -28,7 +28,10 @@ khal-std = { workspace = true } # glamx provides UVec3 and other glam types (no_std compatible, used on all targets). glamx = { version = "0.2", default-features = false, features = ["nostd-libm", "bytemuck"] } -# Host-only dependencies (excluded on GPU targets: spirv and nvptx64). +[build-dependencies] +khal-std = { workspace = true } + +# Host-only dependencies (excluded on GPU targets: spirv, nvptx64). [target.'cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))'.dependencies] khal = { workspace = true } bytemuck = { version = "1", features = ["derive"] } diff --git a/vortx-shaders/build.rs b/vortx-shaders/build.rs index 8016999..d50740c 100644 --- a/vortx-shaders/build.rs +++ b/vortx-shaders/build.rs @@ -1,8 +1,3 @@ -// Re-exports this crate's source location to host crates that build the -// shaders. fn main() { - let manifest_dir = - std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set by cargo"); - println!("cargo::metadata=manifest_dir={manifest_dir}"); - println!("cargo:rerun-if-changed=build.rs"); + khal_std::setup_shader_crate_build(); } diff --git a/vortx-shaders/src/lib.rs b/vortx-shaders/src/lib.rs index 6eac059..6ab2b13 100644 --- a/vortx-shaders/src/lib.rs +++ b/vortx-shaders/src/lib.rs @@ -8,7 +8,7 @@ #![allow(clippy::too_many_arguments)] // Enable std on host for generated ShaderArgs structs (not on GPU targets). -#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))] +#[cfg(not(target_arch_is_gpu))] extern crate std; pub mod linalg; diff --git a/vortx-shaders/src/linalg/mod.rs b/vortx-shaders/src/linalg/mod.rs index dd0b36b..01f73a3 100644 --- a/vortx-shaders/src/linalg/mod.rs +++ b/vortx-shaders/src/linalg/mod.rs @@ -13,13 +13,13 @@ pub use shape::Shape; pub use shape::{Shapes1, Shapes2, Shapes3}; // Re-export generated ShaderArgs structs (only available on host) -#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))] +#[cfg(not(target_arch_is_gpu))] pub use contiguous::{Contiguous, ContiguousWithOffset}; -#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))] +#[cfg(not(target_arch_is_gpu))] pub use gemm::{GemmNaive, GemmTiled}; -#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))] +#[cfg(not(target_arch_is_gpu))] pub use op_assign::{GpuAdd, GpuCopy, GpuCopyWithOffsets, GpuDiv, GpuMul, GpuSub}; -#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))] +#[cfg(not(target_arch_is_gpu))] pub use reduce::{ReduceAdd, ReduceMax, ReduceMin, ReduceMul, ReduceSqNorm}; -#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))] +#[cfg(not(target_arch_is_gpu))] pub use repeat::Repeat; diff --git a/vortx-shaders/src/linalg/op_assign.rs b/vortx-shaders/src/linalg/op_assign.rs index 0402b9f..ae7f76b 100644 --- a/vortx-shaders/src/linalg/op_assign.rs +++ b/vortx-shaders/src/linalg/op_assign.rs @@ -17,7 +17,7 @@ const MAX_NUM_THREADS: u32 = MAX_NUM_WORKGROUPS * WORKGROUP_SIZE; #[repr(C)] #[derive(Clone, Copy)] #[cfg_attr( - not(any(target_arch = "spirv", target_arch = "nvptx64")), + not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable) )] pub struct BinOpOffsets { diff --git a/vortx-shaders/src/linalg/shape.rs b/vortx-shaders/src/linalg/shape.rs index 369f335..d1ffa11 100644 --- a/vortx-shaders/src/linalg/shape.rs +++ b/vortx-shaders/src/linalg/shape.rs @@ -9,7 +9,7 @@ use glamx::UVec4; #[repr(C)] #[derive(Clone, Copy)] #[cfg_attr( - not(any(target_arch = "spirv", target_arch = "nvptx64")), + not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable) )] pub struct Shape { @@ -104,7 +104,7 @@ pub fn div_ceil4(a: u32) -> u32 { #[repr(C)] #[derive(Clone, Copy)] #[cfg_attr( - not(any(target_arch = "spirv", target_arch = "nvptx64")), + not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable) )] pub struct Shapes2 { @@ -119,7 +119,7 @@ pub struct Shapes2 { #[repr(C)] #[derive(Clone, Copy)] #[cfg_attr( - not(any(target_arch = "spirv", target_arch = "nvptx64")), + not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable) )] pub struct Shapes3 { @@ -136,7 +136,7 @@ pub struct Shapes3 { #[repr(C)] #[derive(Clone, Copy)] #[cfg_attr( - not(any(target_arch = "spirv", target_arch = "nvptx64")), + not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable) )] pub struct Shapes1 { diff --git a/vortx-shaders/src/utils/trig.rs b/vortx-shaders/src/utils/trig.rs index 40fc6e5..ae9d212 100644 --- a/vortx-shaders/src/utils/trig.rs +++ b/vortx-shaders/src/utils/trig.rs @@ -1,6 +1,6 @@ //! Trigonometric utility functions. -#[cfg(any(target_arch = "spirv", target_arch = "nvptx64"))] +#[cfg(target_arch_is_gpu)] use khal_std::num_traits::Float; /// The value of pi. From bd446615bc717620e311d54fa2d60d5106f7a595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Fri, 29 May 2026 18:15:21 +0200 Subject: [PATCH 3/3] feat: switch to khal 0.2 + nalgebra 0.35 --- Cargo.toml | 13 +++++++------ vortx-shaders/Cargo.toml | 2 +- vortx-shaders/src/linalg/op_assign.rs | 5 +---- vortx-shaders/src/linalg/shape.rs | 20 ++++---------------- 4 files changed, 13 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6eebd8f..305fa82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,20 +28,20 @@ subgroup_ops = ["khal/subgroup_ops", "vortx-shaders/subgroup_ops"] version = "0.1.1" [workspace.dependencies] -khal-std = "0.1" -khal = { version = "0.1", features = ["derive"]} +khal-std = "0.2" +khal = { version = "0.2", features = ["derive"]} [dependencies] bytemuck = "1" include_dir = "0.7" -nalgebra = "0.34" +nalgebra = "0.35" khal = { workspace = true } khal-std = { workspace = true } # Shader crate provides both GPU shader code and generated ShaderArgs via spirv_bindgen vortx-shaders = { version = "0.1", path = "vortx-shaders" } [dev-dependencies] -nalgebra = { version = "0.34", features = ["rand"] } +nalgebra = { version = "0.35", features = ["rand"] } futures-test = "0.3" serial_test = "3" approx = "0.5" @@ -52,12 +52,13 @@ anyhow = "1" wgpu = "29" [build-dependencies] -khal-builder = "0.1.1" +khal-builder = "0.2" # To build the shader from the dependency instead of local path. vortx-shaders = { version = "0.1.1", path = "./vortx-shaders" } -#[patch.crates-io] +[patch.crates-io] #khal-builder = { path = "../khal/crates/khal-builder" } #khal = { path = "../khal/crates/khal" } #khal-std = { path = "../khal/crates/khal-std" } +#khal-derive = { path = "../khal/crates/khal-derive" } #glamx = { git = "https://github.com/dimforge/glamx", branch = "bytemuck" } diff --git a/vortx-shaders/Cargo.toml b/vortx-shaders/Cargo.toml index 72fc070..3ababf5 100644 --- a/vortx-shaders/Cargo.toml +++ b/vortx-shaders/Cargo.toml @@ -26,7 +26,7 @@ cuda = ["khal-std/cuda", "khal/cuda"] [dependencies] khal-std = { workspace = true } # glamx provides UVec3 and other glam types (no_std compatible, used on all targets). -glamx = { version = "0.2", default-features = false, features = ["nostd-libm", "bytemuck"] } +glamx = { version = "0.3", default-features = false, features = ["nostd-libm", "bytemuck"] } [build-dependencies] khal-std = { workspace = true } diff --git a/vortx-shaders/src/linalg/op_assign.rs b/vortx-shaders/src/linalg/op_assign.rs index ae7f76b..7bc134e 100644 --- a/vortx-shaders/src/linalg/op_assign.rs +++ b/vortx-shaders/src/linalg/op_assign.rs @@ -16,10 +16,7 @@ const MAX_NUM_THREADS: u32 = MAX_NUM_WORKGROUPS * WORKGROUP_SIZE; /// Binary operation offsets. #[repr(C)] #[derive(Clone, Copy)] -#[cfg_attr( - not(target_arch_is_gpu), - derive(bytemuck::Pod, bytemuck::Zeroable) -)] +#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))] pub struct BinOpOffsets { pub a: u32, pub b: u32, diff --git a/vortx-shaders/src/linalg/shape.rs b/vortx-shaders/src/linalg/shape.rs index d1ffa11..fa9a3aa 100644 --- a/vortx-shaders/src/linalg/shape.rs +++ b/vortx-shaders/src/linalg/shape.rs @@ -8,10 +8,7 @@ use glamx::UVec4; /// (Samples, Channels, Height, Width), where height is the row count, and width the column count. #[repr(C)] #[derive(Clone, Copy)] -#[cfg_attr( - not(target_arch_is_gpu), - derive(bytemuck::Pod, bytemuck::Zeroable) -)] +#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))] pub struct Shape { /// Number of rows in each matrix of the tensor. pub n: u32, @@ -103,10 +100,7 @@ pub fn div_ceil4(a: u32) -> u32 { #[cfg(feature = "push_constants")] #[repr(C)] #[derive(Clone, Copy)] -#[cfg_attr( - not(target_arch_is_gpu), - derive(bytemuck::Pod, bytemuck::Zeroable) -)] +#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))] pub struct Shapes2 { /// First shape (typically output or left operand). pub shape_a: Shape, @@ -118,10 +112,7 @@ pub struct Shapes2 { #[cfg(feature = "push_constants")] #[repr(C)] #[derive(Clone, Copy)] -#[cfg_attr( - not(target_arch_is_gpu), - derive(bytemuck::Pod, bytemuck::Zeroable) -)] +#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))] pub struct Shapes3 { /// Output shape. pub shape_out: Shape, @@ -135,10 +126,7 @@ pub struct Shapes3 { #[cfg(feature = "push_constants")] #[repr(C)] #[derive(Clone, Copy)] -#[cfg_attr( - not(target_arch_is_gpu), - derive(bytemuck::Pod, bytemuck::Zeroable) -)] +#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))] pub struct Shapes1 { /// The shape. pub shape: Shape,