diff --git a/crates/khal-std/Cargo.toml b/crates/khal-std/Cargo.toml
index e60ec2f..0a4e29d 100644
--- a/crates/khal-std/Cargo.toml
+++ b/crates/khal-std/Cargo.toml
@@ -25,7 +25,7 @@ cuda = []
 unsafe_remove_boundchecks = []
 
 [dependencies]
-glamx = { version = "0.2", default-features = false, features = ["nostd-libm", "bytemuck"] }
+glamx = { version = "0.3", default-features = false, features = ["nostd-libm", "bytemuck", "u32", "i32", "f64"] }
 rayon = { version = "1", optional = true }
 corosensei = { version = "0.3", optional = true }
 spirv-std-macros = "0.10.0-alpha.1"
diff --git a/crates/khal-std/src/build_script.rs b/crates/khal-std/src/build_script.rs
new file mode 100644
index 0000000..8510b35
--- /dev/null
+++ b/crates/khal-std/src/build_script.rs
@@ -0,0 +1,48 @@
+//! Build-script helpers for shader crates.
+//!
+//! Host-only module (gated out of GPU-target builds). Intended to be called
+//! from a shader crate's `build.rs` — see [`setup_shader_crate_build`].
+
+/// Standard `build.rs` setup that every shader crate should run.
+///
+/// Does four things:
+///
+/// 1. Emits `cargo::metadata=manifest_dir=<dir>` (the value of
+///    `CARGO_MANIFEST_DIR`) so that host crates consuming this shader crate
+///    via `KhalBuilder::from_dependency` discover the shader sources both
+///    in-workspace and from `crates.io`-fetched copies.
+/// 2. Emits `cargo::rerun-if-changed=build.rs` so that cargo does not re-run
+///    the build script on every source change.
+/// 3. Declares the `target_arch_is_gpu` cfg via `cargo::rustc-check-cfg`
+///    so `#[cfg(target_arch_is_gpu)]` / `#[cfg(not(target_arch_is_gpu))]`
+///    don't trip the `unexpected_cfgs` lint.
+/// 4. Sets `target_arch_is_gpu` when compiling for any GPU target
+///    (SPIR-V, NVPTX). The host CPU build sees it unset.
+///
+/// Call from `build.rs`:
+///
+/// ```no_run
+/// khal_std::build_script::setup_shader_crate_build();
+/// ```
+///
+/// The shader crate must list `khal-std` as a `[build-dependencies]` entry
+/// (in addition to its regular `[dependencies]` use).
+///
+/// # Panics
+///
+/// Panics if `CARGO_MANIFEST_DIR` is not set, i.e. when called outside of a
+/// cargo build script.
+pub fn setup_shader_crate_build() {
+    let manifest_dir =
+        std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set by cargo");
+    println!("cargo::metadata=manifest_dir={manifest_dir}");
+    // Use the two-colon directive syntax throughout; the single-colon form is
+    // the deprecated spelling and the other directives here already need `cargo::`.
+    println!("cargo::rerun-if-changed=build.rs");
+
+    println!("cargo::rustc-check-cfg=cfg(target_arch_is_gpu)");
+
+    // A missing `CARGO_CFG_TARGET_ARCH` is treated as "not a GPU target".
+    let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
+    if matches!(target_arch.as_str(), "spirv" | "nvptx64") {
+        println!("cargo::rustc-cfg=target_arch_is_gpu");
+    }
+}
diff --git a/crates/khal-std/src/lib.rs b/crates/khal-std/src/lib.rs
index c154e5f..dd39cbd 100644
--- a/crates/khal-std/src/lib.rs
+++ b/crates/khal-std/src/lib.rs
@@ -23,6 +23,12 @@ pub mod num_traits;
 /// Synchronization primitives (barriers, atomics).
 pub mod sync;
 
+/// Build-script helpers for shader crates. Host-only.
+#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
+pub mod build_script;
+#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
+pub use build_script::*;
+
 /// Re-export of the `glamx` math library.
pub use glamx; diff --git a/crates/khal/Cargo.toml b/crates/khal/Cargo.toml index 8564d29..b23d4b5 100644 --- a/crates/khal/Cargo.toml +++ b/crates/khal/Cargo.toml @@ -13,6 +13,7 @@ derive = ["khal-derive"] webgpu = ["khal-derive?/webgpu"] cpu = ["khal-derive?/cpu"] cuda = ["dep:cudarc", "khal-derive?/cuda"] +metal = ["dep:metal", "dep:naga", "khal-derive?/webgpu"] push_constants = [] subgroup_ops = [] @@ -34,4 +35,8 @@ regex = "1" bitflags = "2" # For test_shader_compilation -paste = "1" \ No newline at end of file +paste = "1" + +[target.'cfg(target_os = "macos")'.dependencies] +metal = { version = "0.32", optional = true } +naga = { version = "29", optional = true, features = ["spv-in", "msl-out"] } \ No newline at end of file diff --git a/crates/khal/src/backend/any_backend.rs b/crates/khal/src/backend/any_backend.rs index 925a851..7ecc166 100644 --- a/crates/khal/src/backend/any_backend.rs +++ b/crates/khal/src/backend/any_backend.rs @@ -6,6 +6,12 @@ use crate::backend::cuda::{ CudaEncoder as CudaEncoderInner, CudaFunction as CudaFunctionInner, CudaModule as CudaModuleInner, CudaPass as CudaPassInner, CudaTimestamps, }; +#[cfg(feature = "metal")] +use crate::backend::metal::{ + Metal, MetalBackendError, MetalBuffer, MetalBufferSlice, MetalDispatch as MetalDispatchInner, + MetalEncoder as MetalEncoderInner, MetalFunction as MetalFunctionInner, + MetalModule as MetalModuleInner, MetalPass as MetalPassInner, MetalTimestamps, +}; #[cfg(feature = "webgpu")] use crate::backend::webgpu::CommandEncoderExt; use crate::backend::webgpu::WebGpuTimestamps; @@ -30,6 +36,8 @@ pub enum GpuBackend { WebGpu(WebGpu), #[cfg(feature = "cuda")] Cuda(Cuda), + #[cfg(feature = "metal")] + Metal(Metal), #[cfg(feature = "cpu")] Cpu, } @@ -42,6 +50,8 @@ impl GpuBackend { Self::WebGpu(_) => super::CompileTarget::Wgsl, #[cfg(feature = "cuda")] Self::Cuda(_) => super::CompileTarget::Ptx, + #[cfg(feature = "metal")] + Self::Metal(_) => super::CompileTarget::Spirv, #[cfg(feature = "cpu")] 
Self::Cpu => super::CompileTarget::Wgsl, } @@ -53,6 +63,12 @@ impl GpuBackend { matches!(self, Self::Cuda(..)) } + /// Returns `true` if this is the Metal backend. + #[cfg(feature = "metal")] + pub fn is_metal(&self) -> bool { + matches!(self, Self::Metal(..)) + } + /// Loads a SPIR-V module using passthrough loading (bypassing naga validation). /// /// On the WebGPU backend, this uses `create_shader_module_spirv` to pass raw SPIR-V @@ -73,6 +89,8 @@ impl GpuBackend { // For non-WebGPU backends, SPIR-V is already loaded natively. #[cfg(feature = "cuda")] Self::Cuda(_) => ::load_module_bytes(self, bytes), + #[cfg(feature = "metal")] + Self::Metal(_) => ::load_module_bytes(self, bytes), #[cfg(feature = "cpu")] Self::Cpu => ::load_module_bytes(self, bytes), } @@ -86,6 +104,8 @@ pub enum GpuBuffer { WebGpu(::Buffer), #[cfg(feature = "cuda")] Cuda(CudaBuffer), + #[cfg(feature = "metal")] + Metal(MetalBuffer), #[cfg(feature = "cpu")] Cpu(Vec), } @@ -158,6 +178,8 @@ impl<'a, T: DeviceValue> AsGpuSlice for GpuBufferSlice<'a, T> { Self::WebGpu(s) => GpuBufferSlice::WebGpu(*s), #[cfg(feature = "cuda")] Self::Cuda(s) => GpuBufferSlice::Cuda(*s), + #[cfg(feature = "metal")] + Self::Metal(s) => GpuBufferSlice::Metal(*s), #[cfg(feature = "cpu")] Self::Cpu(s) => GpuBufferSlice::Cpu(s), } @@ -171,6 +193,8 @@ impl<'a, T: DeviceValue> AsGpuSliceMut for GpuBufferSliceMut<'a, T> { Self::WebGpu(s) => GpuBufferSliceMut::WebGpu(*s), #[cfg(feature = "cuda")] Self::Cuda(s) => GpuBufferSliceMut::Cuda(*s), + #[cfg(feature = "metal")] + Self::Metal(s) => GpuBufferSliceMut::Metal(*s), #[cfg(feature = "cpu")] Self::Cpu(s) => GpuBufferSliceMut::Cpu(s), } @@ -210,6 +234,8 @@ pub enum GpuBufferSlice<'a, T: DeviceValue> { WebGpu(::BufferSlice<'a, T>), #[cfg(feature = "cuda")] Cuda(CudaBufferSlice), + #[cfg(feature = "metal")] + Metal(MetalBufferSlice<'a>), #[cfg(feature = "cpu")] Cpu(&'a [T]), } @@ -250,6 +276,8 @@ impl<'a, T: DeviceValue + bytemuck::Pod> GpuBufferSlice<'a, T> { Self::WebGpu(slice) 
=> GpuBufferSlice::WebGpu(slice), #[cfg(feature = "cuda")] Self::Cuda(slice) => GpuBufferSlice::Cuda(slice), + #[cfg(feature = "metal")] + Self::Metal(slice) => GpuBufferSlice::Metal(slice), #[cfg(feature = "cpu")] Self::Cpu(slice) => GpuBufferSlice::Cpu(bytemuck::cast_slice(slice)), } @@ -290,6 +318,17 @@ impl<'a, T: DeviceValue + bytemuck::Pod> GpuBufferSlice<'a, T> { ); GpuBufferSlice::Cuda(slice) } + #[cfg(feature = "metal")] + Self::Metal(slice) => { + let target_size = core::mem::size_of::() as u64; + assert!( + target_size > 0 && slice.byte_len % target_size == 0, + "Cannot reinterpret Metal buffer: byte length {} is not a multiple of size_of::<{}>()", + slice.byte_len, + core::any::type_name::() + ); + GpuBufferSlice::Metal(slice) + } #[cfg(feature = "cpu")] Self::Cpu(slice) => GpuBufferSlice::Cpu(bytemuck::cast_slice(slice)), } @@ -303,6 +342,8 @@ pub enum GpuBufferSliceMut<'a, T: DeviceValue> { WebGpu(::BufferSlice<'a, T>), // TODO: add a mut version of ::BufferSlice? #[cfg(feature = "cuda")] Cuda(CudaBufferSlice), + #[cfg(feature = "metal")] + Metal(MetalBufferSlice<'a>), #[cfg(feature = "cpu")] Cpu(&'a mut [T]), } @@ -325,6 +366,8 @@ impl<'a, T: DeviceValue + bytemuck::Pod> GpuBufferSliceMut<'a, T> { Self::WebGpu(slice) => GpuBufferSliceMut::WebGpu(slice), #[cfg(feature = "cuda")] Self::Cuda(slice) => GpuBufferSliceMut::Cuda(slice), + #[cfg(feature = "metal")] + Self::Metal(slice) => GpuBufferSliceMut::Metal(slice), #[cfg(feature = "cpu")] Self::Cpu(slice) => GpuBufferSliceMut::Cpu(bytemuck::cast_slice_mut(slice)), } @@ -363,6 +406,17 @@ impl<'a, T: DeviceValue + bytemuck::Pod> GpuBufferSliceMut<'a, T> { ); GpuBufferSliceMut::Cuda(slice) } + #[cfg(feature = "metal")] + Self::Metal(slice) => { + let target_size = core::mem::size_of::() as u64; + assert!( + target_size > 0 && slice.byte_len % target_size == 0, + "Cannot reinterpret Metal buffer: byte length {} is not a multiple of size_of::<{}>()", + slice.byte_len, + core::any::type_name::() + ); + 
GpuBufferSliceMut::Metal(slice) + } #[cfg(feature = "cpu")] Self::Cpu(slice) => GpuBufferSliceMut::Cpu(bytemuck::cast_slice_mut(slice)), } @@ -387,6 +441,8 @@ pub enum GpuEncoder { WebGpu(::Encoder), #[cfg(feature = "cuda")] Cuda(CudaEncoderInner), + #[cfg(feature = "metal")] + Metal(MetalEncoderInner), #[cfg(feature = "cpu")] Cpu, Noop, @@ -400,6 +456,8 @@ pub enum GpuPass { WebGpu(::Pass), #[cfg(feature = "cuda")] Cuda(CudaPassInner), + #[cfg(feature = "metal")] + Metal(MetalPassInner), #[cfg(feature = "cpu")] Cpu(Option), Noop, @@ -440,6 +498,28 @@ impl GpuPass { matches!(self, Self::Cuda(..)) } + /// Inserts a buffer-scope memory barrier between dispatches within + /// this compute pass. + /// + /// Backends that auto-insert barriers between consecutive dispatches + /// (WebGPU/wgpu, CUDA's stream-ordered execution, CPU, Noop) treat + /// this as a no-op. The native Metal backend, which uses + /// `MTLDispatchType::Concurrent`, emits a real + /// `memoryBarrierWithScope:MTLBarrierScopeBuffers` so subsequent + /// dispatches see writes from earlier dispatches in the same pass. + /// + /// Call this between two dispatches inside a single + /// [`GpuEncoder::begin_pass`] when the second reads from a buffer + /// the first wrote to. No barrier is needed across pass boundaries — + /// `begin_pass` / `end_pass` already synchronize on every backend. + pub fn memory_barrier(&mut self) { + match self { + #[cfg(feature = "metal")] + Self::Metal(pass) => pass.memory_barrier(), + _ => {} + } + } + /// Begins a compute dispatch within this pass, binding the given function. 
pub fn begin_dispatch<'a>(&'a mut self, function: &'a InnerGpuFunction) -> GpuDispatch<'a> { match (self, function) { @@ -455,6 +535,17 @@ impl GpuPass { #[cfg(feature = "push_constants")] push_constants: Vec::new(), }), + #[cfg(feature = "metal")] + (Self::Metal(pass), InnerGpuFunction::Metal(f)) => { + pass.encoder.set_compute_pipeline_state(&f.pipeline); + GpuDispatch::Metal(MetalDispatchInner { + encoder: &pass.encoder, + function: f, + args: Vec::new(), + #[cfg(feature = "push_constants")] + push_constants: Vec::new(), + }) + } #[cfg(feature = "cpu")] (Self::Cpu(_), InnerGpuFunction::Noop) => GpuDispatch::Noop, (Self::Noop, InnerGpuFunction::Noop) => GpuDispatch::Noop, @@ -470,6 +561,8 @@ pub enum GpuModule { WebGpu(::Module), #[cfg(feature = "cuda")] Cuda(CudaModuleInner), + #[cfg(feature = "metal")] + Metal(MetalModuleInner), Noop, } @@ -505,6 +598,8 @@ pub enum GpuFunction { WebGpu(::Function, PhantomData), #[cfg(feature = "cuda")] Cuda(CudaFunctionInner, PhantomData), + #[cfg(feature = "metal")] + Metal(MetalFunctionInner, PhantomData), Noop(PhantomData), } @@ -519,6 +614,8 @@ pub enum InnerGpuFunction { WebGpu(::Function), #[cfg(feature = "cuda")] Cuda(CudaFunctionInner), + #[cfg(feature = "metal")] + Metal(MetalFunctionInner), Noop, } @@ -530,6 +627,8 @@ pub enum GpuDispatch<'a> { WebGpu(::Dispatch<'a>), #[cfg(feature = "cuda")] Cuda(CudaDispatchInner<'a>), + #[cfg(feature = "metal")] + Metal(MetalDispatchInner<'a>), Noop, #[doc(hidden)] _Phantom(std::marker::PhantomData<&'a ()>), @@ -545,6 +644,9 @@ pub enum GpuBackendError { #[cfg(feature = "cuda")] #[error(transparent)] Cuda(#[from] CudaBackendError), + #[cfg(feature = "metal")] + #[error(transparent)] + Metal(#[from] MetalBackendError), #[error(transparent)] ShaderArgs(#[from] ShaderArgsError), #[error("GPU context not found in local storage")] @@ -586,6 +688,8 @@ pub enum GpuTimestamps { WebGpu(WebGpuTimestamps), #[cfg(feature = "cuda")] Cuda(CudaTimestamps), + #[cfg(feature = "metal")] + 
Metal(MetalTimestamps), #[cfg(feature = "cpu")] Cpu(CpuTimestamps), Noop, @@ -603,6 +707,10 @@ impl GpuTimestamps { .unwrap_or(GpuTimestamps::Noop), #[cfg(feature = "cuda")] GpuBackend::Cuda(cuda) => GpuTimestamps::Cuda(CudaTimestamps::new(cuda)), + #[cfg(feature = "metal")] + GpuBackend::Metal(metal) => MetalTimestamps::new(metal, capacity) + .map(GpuTimestamps::Metal) + .unwrap_or(GpuTimestamps::Noop), #[cfg(feature = "cpu")] GpuBackend::Cpu => GpuTimestamps::Cpu(CpuTimestamps::new()), #[allow(unreachable_patterns)] @@ -622,6 +730,8 @@ impl GpuTimestamps { GpuTimestamps::WebGpu(ts) => ts.reset(), #[cfg(feature = "cuda")] GpuTimestamps::Cuda(ts) => ts.reset(), + #[cfg(feature = "metal")] + GpuTimestamps::Metal(ts) => ts.reset(), #[cfg(feature = "cpu")] GpuTimestamps::Cpu(ts) => ts.entries.lock().unwrap().clear(), GpuTimestamps::Noop => {} @@ -649,6 +759,8 @@ impl GpuTimestamps { (GpuTimestamps::WebGpu(ts), GpuBackend::WebGpu(webgpu)) => ts.read(webgpu).await, #[cfg(feature = "cuda")] (GpuTimestamps::Cuda(ts), _) => Ok(ts.read()?), + #[cfg(feature = "metal")] + (GpuTimestamps::Metal(ts), _) => Ok(ts.read()?), #[cfg(feature = "cpu")] (GpuTimestamps::Cpu(ts), _) => Ok(ts.entries.lock().unwrap().clone()), _ => Ok(Vec::new()), @@ -679,6 +791,8 @@ impl Backend for GpuBackend { Self::WebGpu(backend) => Ok(GpuModule::WebGpu(backend.load_module(data)?)), #[cfg(feature = "cuda")] Self::Cuda(backend) => Ok(GpuModule::Cuda(backend.load_module(data)?)), + #[cfg(feature = "metal")] + Self::Metal(backend) => Ok(GpuModule::Metal(backend.load_module(data)?)), #[cfg(feature = "cpu")] Self::Cpu => Ok(GpuModule::Noop), } @@ -690,6 +804,8 @@ impl Backend for GpuBackend { Self::WebGpu(backend) => Ok(GpuModule::WebGpu(backend.load_module_bytes(bytes)?)), #[cfg(feature = "cuda")] Self::Cuda(backend) => Ok(GpuModule::Cuda(backend.load_module_bytes(bytes)?)), + #[cfg(feature = "metal")] + Self::Metal(backend) => Ok(GpuModule::Metal(backend.load_module_bytes(bytes)?)), #[cfg(feature = 
"cpu")] Self::Cpu => Ok(GpuModule::Noop), } @@ -710,6 +826,10 @@ impl Backend for GpuBackend { (Self::Cuda(backend), GpuModule::Cuda(module)) => Ok(InnerGpuFunction::Cuda( backend.load_function(module, entry_point, push_constant_size)?, )), + #[cfg(feature = "metal")] + (Self::Metal(backend), GpuModule::Metal(module)) => Ok(InnerGpuFunction::Metal( + backend.load_function(module, entry_point, push_constant_size)?, + )), #[cfg(feature = "cpu")] (Self::Cpu, GpuModule::Noop) => Ok(InnerGpuFunction::Noop), _ => panic!("Invalid backend/module type pair"), @@ -742,6 +862,15 @@ impl Backend for GpuBackend { layouts, )?)) } + #[cfg(feature = "metal")] + (Self::Metal(backend), GpuModule::Metal(module)) => Ok(InnerGpuFunction::Metal( + backend.load_function_with_layouts( + module, + entry_point, + push_constant_size, + layouts, + )?, + )), #[cfg(feature = "cpu")] (Self::Cpu, GpuModule::Noop) => Ok(InnerGpuFunction::Noop), _ => panic!("Invalid backend/module type pair"), @@ -757,6 +886,8 @@ impl Backend for GpuBackend { Self::WebGpu(backend) => GpuEncoder::WebGpu(backend.begin_encoding()), #[cfg(feature = "cuda")] Self::Cuda(backend) => GpuEncoder::Cuda(backend.begin_encoding()), + #[cfg(feature = "metal")] + Self::Metal(backend) => GpuEncoder::Metal(backend.begin_encoding()), #[cfg(feature = "cpu")] Self::Cpu => GpuEncoder::Cpu, } @@ -776,6 +907,10 @@ impl Backend for GpuBackend { (Self::Cuda(backend), GpuPass::Cuda(pass), InnerGpuFunction::Cuda(function)) => { GpuDispatch::Cuda(backend.begin_dispatch(pass, function)) } + #[cfg(feature = "metal")] + (Self::Metal(backend), GpuPass::Metal(pass), InnerGpuFunction::Metal(function)) => { + GpuDispatch::Metal(backend.begin_dispatch(pass, function)) + } #[cfg(feature = "cpu")] (Self::Cpu, GpuPass::Cpu(_), InnerGpuFunction::Noop) => GpuDispatch::Noop, (_, GpuPass::Noop, InnerGpuFunction::Noop) => GpuDispatch::Noop, @@ -789,6 +924,8 @@ impl Backend for GpuBackend { (Self::WebGpu(backend), GpuEncoder::WebGpu(encoder)) => 
Ok(backend.submit(encoder)?), #[cfg(feature = "cuda")] (Self::Cuda(backend), GpuEncoder::Cuda(encoder)) => Ok(backend.submit(encoder)?), + #[cfg(feature = "metal")] + (Self::Metal(backend), GpuEncoder::Metal(encoder)) => Ok(backend.submit(encoder)?), #[cfg(feature = "cpu")] (Self::Cpu, GpuEncoder::Cpu) => Ok(()), _ => panic!("Invalid backend/encoder type pair"), @@ -808,6 +945,8 @@ impl Backend for GpuBackend { GpuBackend::WebGpu(backend) => Ok(GpuBuffer::WebGpu(backend.init_buffer(data, usage)?)), #[cfg(feature = "cuda")] GpuBackend::Cuda(backend) => Ok(GpuBuffer::Cuda(backend.init_buffer(data, usage)?)), + #[cfg(feature = "metal")] + GpuBackend::Metal(backend) => Ok(GpuBuffer::Metal(backend.init_buffer(data, usage)?)), #[cfg(feature = "cpu")] GpuBackend::Cpu => Ok(GpuBuffer::Cpu(data.to_vec())), } @@ -827,6 +966,10 @@ impl Backend for GpuBackend { GpuBackend::Cuda(backend) => { Ok(GpuBuffer::Cuda(backend.uninit_buffer::(len, usage)?)) } + #[cfg(feature = "metal")] + GpuBackend::Metal(backend) => { + Ok(GpuBuffer::Metal(backend.uninit_buffer::(len, usage)?)) + } #[cfg(feature = "cpu")] GpuBackend::Cpu => { let mut v = Vec::with_capacity(len); @@ -852,6 +995,10 @@ impl Backend for GpuBackend { (GpuBackend::Cuda(backend), GpuBuffer::Cuda(buffer)) => { backend.write_buffer(buffer, offset, data)? } + #[cfg(feature = "metal")] + (GpuBackend::Metal(backend), GpuBuffer::Metal(buffer)) => { + backend.write_buffer(buffer, offset, data)? 
+ } #[cfg(feature = "cpu")] (GpuBackend::Cpu, GpuBuffer::Cpu(buffer)) => { let start = offset as usize; @@ -870,6 +1017,8 @@ impl Backend for GpuBackend { GpuBackend::WebGpu(backend) => Ok(backend.synchronize()?), #[cfg(feature = "cuda")] GpuBackend::Cuda(backend) => Ok(backend.synchronize()?), + #[cfg(feature = "metal")] + GpuBackend::Metal(backend) => Ok(backend.synchronize()?), #[cfg(feature = "cpu")] GpuBackend::Cpu => Ok(()), } @@ -889,6 +1038,10 @@ impl Backend for GpuBackend { (GpuBackend::Cuda(backend), GpuBuffer::Cuda(buffer)) => { backend.read_buffer(buffer, out).await? } + #[cfg(feature = "metal")] + (GpuBackend::Metal(backend), GpuBuffer::Metal(buffer)) => { + backend.read_buffer(buffer, out).await? + } #[cfg(feature = "cpu")] (GpuBackend::Cpu, GpuBuffer::Cpu(buffer)) => { out[..buffer.len()].copy_from_slice(buffer); @@ -913,6 +1066,10 @@ impl Backend for GpuBackend { (GpuBackend::Cuda(backend), GpuBuffer::Cuda(buffer)) => { backend.slow_read_buffer(buffer, out).await? } + #[cfg(feature = "metal")] + (GpuBackend::Metal(backend), GpuBuffer::Metal(buffer)) => { + backend.slow_read_buffer(buffer, out).await? 
+ } #[cfg(feature = "cpu")] (GpuBackend::Cpu, GpuBuffer::Cpu(buffer)) => { out[..buffer.len()].copy_from_slice(buffer); @@ -943,6 +1100,14 @@ impl Encoder for GpuEncoder { GpuPass::Cuda(encoder.begin_pass(label, None)) } } + #[cfg(feature = "metal")] + GpuEncoder::Metal(encoder) => { + if let Some(GpuTimestamps::Metal(ts)) = timestamps { + GpuPass::Metal(encoder.begin_pass(label, Some(ts))) + } else { + GpuPass::Metal(encoder.begin_pass(label, None)) + } + } #[cfg(feature = "cpu")] GpuEncoder::Cpu => { let timer = if let Some(GpuTimestamps::Cpu(ts)) = timestamps { @@ -991,6 +1156,17 @@ impl Encoder for GpuEncoder { copy_len, )?; } + #[cfg(feature = "metal")] + (GpuEncoder::Metal(encoder), GpuBuffer::Metal(source), GpuBuffer::Metal(target)) => { + Encoder::::copy_buffer_to_buffer::( + encoder, + source, + source_offset, + target, + target_offset, + copy_len, + )?; + } #[cfg(feature = "cpu")] (GpuEncoder::Cpu, GpuBuffer::Cpu(source), GpuBuffer::Cpu(target)) => { target[target_offset..target_offset + copy_len] @@ -1000,6 +1176,19 @@ impl Encoder for GpuEncoder { } Ok(()) } + + fn memory_barrier(&mut self, pass: &mut GpuPass) { + match (self, pass) { + #[cfg(feature = "metal")] + (GpuEncoder::Metal(encoder), GpuPass::Metal(pass)) => { + Encoder::::memory_barrier(encoder, pass); + } + // Backends that already auto-insert barriers between dispatches + // (WebGPU/wgpu, CUDA's stream-ordered execution, CPU, Noop) treat + // this as a no-op. 
+ _ => {} + } + } } impl<'a> Dispatch<'a, GpuBackend> for GpuDispatch<'a> { @@ -1010,6 +1199,8 @@ impl<'a> Dispatch<'a, GpuBackend> for GpuDispatch<'a> { GpuDispatch::WebGpu(dispatch) => dispatch.set_push_constants(data), #[cfg(feature = "cuda")] GpuDispatch::Cuda(dispatch) => dispatch.set_push_constants(data), + #[cfg(feature = "metal")] + GpuDispatch::Metal(dispatch) => dispatch.set_push_constants(data), GpuDispatch::Noop => {} GpuDispatch::_Phantom(_) => unreachable!(), } @@ -1049,6 +1240,20 @@ impl<'a> Dispatch<'a, GpuBackend> for GpuDispatch<'a> { }; dispatch.launch(cuda_grid, block_dim)?; } + #[cfg(feature = "metal")] + GpuDispatch::Metal(dispatch) => { + let grid: DispatchGrid<'b, GpuBackend> = grid.into(); + let metal_grid = match grid { + DispatchGrid::Grid(dims) => DispatchGrid::Grid(dims), + DispatchGrid::ThreadCount(threads) => DispatchGrid::ThreadCount(threads), + DispatchGrid::Indirect(buffer) => match buffer { + GpuBuffer::Metal(buf) => DispatchGrid::Indirect(buf), + #[allow(unreachable_patterns)] + _ => panic!("Invalid buffer type for Metal dispatch"), + }, + }; + dispatch.launch(metal_grid, block_dim)?; + } GpuDispatch::Noop => {} GpuDispatch::_Phantom(_) => unreachable!(), } @@ -1067,6 +1272,8 @@ impl CommandEncoderExt for GpuEncoder { GpuEncoder::WebGpu(encoder) => encoder.compute_pass(label), #[cfg(feature = "cuda")] GpuEncoder::Cuda(_) => panic!("Cannot create compute pass from non-WebGpu encoder"), + #[cfg(feature = "metal")] + GpuEncoder::Metal(_) => panic!("Cannot create compute pass from non-WebGpu encoder"), #[cfg(feature = "cpu")] GpuEncoder::Cpu => panic!("Cannot create compute pass from non-WebGpu encoder"), GpuEncoder::Noop => panic!("Cannot create compute pass from non-WebGpu encoder"), @@ -1100,6 +1307,11 @@ impl<'b, T: DeviceValue> crate::ShaderArgs<'b> for GpuBuffer { dispatch.set_arg(binding, buffer.device_ptr_raw(), buffer.byte_len()); Ok(()) } + #[cfg(feature = "metal")] + (GpuBuffer::Metal(buffer), GpuDispatch::Metal(dispatch)) 
=> { + dispatch.set_arg(binding, buffer.raw(), 0, buffer.byte_len() as u64); + Ok(()) + } #[cfg(feature = "cpu")] (GpuBuffer::Cpu(_), GpuDispatch::Noop) => Ok(()), _ => panic!("Invalid buffer/dispatch type combination"), @@ -1127,6 +1339,16 @@ impl<'b, T: DeviceValue> crate::ShaderArgs<'b> for GpuBufferSlice<'_, T> { dispatch.set_arg(binding, slice.offset_ptr(), slice.byte_len); Ok(()) } + #[cfg(feature = "metal")] + (GpuBufferSlice::Metal(slice), GpuDispatch::Metal(dispatch)) => { + dispatch.set_arg( + binding, + slice.buffer(), + slice.byte_offset(), + slice.byte_len(), + ); + Ok(()) + } #[cfg(feature = "cpu")] (GpuBufferSlice::Cpu(_), GpuDispatch::Noop) => Ok(()), _ => panic!("Invalid buffer slice/dispatch type combination"), @@ -1154,6 +1376,16 @@ impl<'b, T: DeviceValue> crate::ShaderArgs<'b> for GpuBufferSliceMut<'_, T> { dispatch.set_arg(binding, slice.offset_ptr(), slice.byte_len); Ok(()) } + #[cfg(feature = "metal")] + (GpuBufferSliceMut::Metal(slice), GpuDispatch::Metal(dispatch)) => { + dispatch.set_arg( + binding, + slice.buffer(), + slice.byte_offset(), + slice.byte_len(), + ); + Ok(()) + } #[cfg(feature = "cpu")] (GpuBufferSliceMut::Cpu(_), GpuDispatch::Noop) => Ok(()), _ => panic!("Invalid mutable buffer slice/dispatch type combination"), @@ -1176,6 +1408,11 @@ impl GpuBuffer { use crate::backend::Buffer; GpuBufferSliceMut::Cuda(Buffer::::slice(buffer, range)) } + #[cfg(feature = "metal")] + GpuBuffer::Metal(buffer) => { + use crate::backend::Buffer; + GpuBufferSliceMut::Metal(Buffer::::slice(buffer, range)) + } #[cfg(feature = "cpu")] GpuBuffer::Cpu(buffer) => { use std::ops::Bound; @@ -1205,6 +1442,11 @@ impl crate::backend::Buffer for GpuBuffer { use crate::backend::Buffer; Buffer::::is_empty(buffer) } + #[cfg(feature = "metal")] + GpuBuffer::Metal(buffer) => { + use crate::backend::Buffer; + Buffer::::is_empty(buffer) + } #[cfg(feature = "cpu")] GpuBuffer::Cpu(buffer) => buffer.is_empty(), } @@ -1222,6 +1464,11 @@ impl crate::backend::Buffer for 
GpuBuffer { use crate::backend::Buffer; Buffer::::len(buffer) } + #[cfg(feature = "metal")] + GpuBuffer::Metal(buffer) => { + use crate::backend::Buffer; + Buffer::::len(buffer) + } #[cfg(feature = "cpu")] GpuBuffer::Cpu(buffer) => buffer.len(), } @@ -1239,6 +1486,11 @@ impl crate::backend::Buffer for GpuBuffer { use crate::backend::Buffer; GpuBufferSlice::Cuda(Buffer::::slice(buffer, range)) } + #[cfg(feature = "metal")] + GpuBuffer::Metal(buffer) => { + use crate::backend::Buffer; + GpuBufferSlice::Metal(Buffer::::slice(buffer, range)) + } #[cfg(feature = "cpu")] GpuBuffer::Cpu(buffer) => { use std::ops::Bound; @@ -1271,6 +1523,11 @@ impl crate::backend::Buffer for GpuBuffer { use crate::backend::Buffer; Buffer::::usage(buffer) } + #[cfg(feature = "metal")] + GpuBuffer::Metal(buffer) => { + use crate::backend::Buffer; + Buffer::::usage(buffer) + } #[cfg(feature = "cpu")] GpuBuffer::Cpu(_) => BufferUsages::all(), // CPU buffers have no usage restrictions } diff --git a/crates/khal/src/backend/metal.rs b/crates/khal/src/backend/metal.rs new file mode 100644 index 0000000..126cff9 --- /dev/null +++ b/crates/khal/src/backend/metal.rs @@ -0,0 +1,1179 @@ +//! Native Metal compute backend. +//! +//! Translates SPIR-V to MSL via `naga` at function-load time, then uses +//! Apple's Metal API directly through the [`metal`] crate. Unlike the +//! WebGPU backend (which goes through `wgpu`), this backend tracks no +//! per-resource state and inserts no implicit barriers between dispatches — +//! callers must use [`crate::backend::Encoder::memory_barrier`] when a +//! later dispatch reads from a buffer written by an earlier one in the +//! same compute pass. 
+
+use crate::backend::{
+    Backend, BufferUsages, CompileTarget, DescriptorType, DeviceValue, Dispatch, DispatchGrid,
+    Encoder, GpuTimestamp, MaybeSendSync, ShaderBinding,
+};
+use crate::shader::{BindGroupLayoutInfo, ShaderArgsError};
+use bytemuck::{AnyBitPattern, NoUninit};
+// metal re-exports objc; pull in its macros so msg_send! / sel! resolve.
+use metal::objc::runtime::Object;
+use metal::objc::{msg_send, sel, sel_impl};
+use metal::{
+    Buffer as MtlBuffer, CommandBuffer, CommandQueue, ComputeCommandEncoder, ComputePassDescriptor,
+    ComputePipelineDescriptor, ComputePipelineState, CounterSampleBuffer,
+    CounterSampleBufferDescriptor, CounterSet, Device, Library, MTLCounterSamplingPoint,
+    MTLDispatchType, MTLResourceOptions, MTLSize, MTLStorageMode, NSRange, NSUInteger,
+};
+use std::collections::{BTreeMap, HashMap};
+use std::marker::PhantomData;
+use std::ops::RangeBounds;
+use std::sync::{Arc, Mutex};
+
+// ── Core backend ───────────────────────────────────────────────────────
+
+/// The native Metal backend: an [`MTLDevice`](metal::Device) together with
+/// the command queue created from it.
+#[derive(Clone)]
+pub struct Metal {
+    device: Device,
+    queue: CommandQueue,
+    /// Compiled MSL libraries, cached by the content hash of their SPIR-V.
+    module_cache: Arc>>,
+    /// GPU timestamp-query capabilities. `None` when the device or driver
+    /// does not expose stage-boundary timestamp sampling.
+    timing_caps: Option>,
+}
+
+// SAFETY: Apple documents MTLDevice and MTLCommandQueue as thread-safe, so the
+// wrapped Objective-C objects may be sent and shared across threads even
+// though the metal-rs crate does not derive these traits automatically.
+unsafe impl Send for Metal {}
+unsafe impl Sync for Metal {}
+
+/// Per-device information cached for issuing GPU timestamp queries.
+struct MetalTimingCaps {
+    /// The "timestamp" common counter set of the device.
+    counter_set: CounterSet,
+    /// Factor that converts raw timestamp ticks into nanoseconds:
+    /// `1.0` on Apple Silicon and AMD; `83.333…` on older Intel iGPUs.
+    period_ns: f64,
+}
+
+// SAFETY: CounterSet wraps an MTLCounterSet which is thread-safe.
+unsafe impl Send for MetalTimingCaps {}
+unsafe impl Sync for MetalTimingCaps {}
+
+impl Metal {
+    /// Builds a Metal backend on top of the system default device.
+    ///
+    /// Fails with [`MetalBackendError::NoDevice`] when the system reports no
+    /// default Metal device.
+    pub fn new() -> Result {
+        let device = Device::system_default().ok_or(MetalBackendError::NoDevice)?;
+        // Field initialisers run in the order written, so the queue is still
+        // created before the timestamp capabilities are probed; `device` is
+        // moved in last, once both borrows have ended.
+        Ok(Self {
+            queue: device.new_command_queue(),
+            timing_caps: detect_timing_caps(&device),
+            module_cache: Arc::new(Mutex::new(HashMap::new())),
+            device,
+        })
+    }
+
+    /// The Metal device this backend runs on.
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+
+    /// The command queue this backend submits work to.
+    pub fn queue(&self) -> &CommandQueue {
+        &self.queue
+    }
+
+    /// `true` when GPU timestamp queries can be used on this device.
+    pub fn timestamp_supported(&self) -> bool {
+        self.timing_caps.is_some()
+    }
+}
+
+/// Checks whether the device can record GPU timestamps.
+///
+/// Yields `None` when stage-boundary counter sampling is unsupported, or when
+/// the device exposes no counter set named "timestamp" (e.g. on older drivers
+/// or virtualized devices).
+fn detect_timing_caps(device: &Device) -> Option> {
+    let samples_at_stage_boundary =
+        device.supports_counter_sampling(MTLCounterSamplingPoint::AtStageBoundary);
+    if !samples_at_stage_boundary {
+        return None;
+    }
+
+    let mut available_sets = device.counter_sets().into_iter();
+    let counter_set = available_sets.find(|set| set.name() == "timestamp")?;
+
+    // Same heuristic as wgpu-hal: pre-Apple-Silicon Intel iGPUs count in
+    // ~83.333 ns ticks, every other device reports nanoseconds directly.
+    let is_intel = device.name().starts_with("Intel");
+    let period_ns = if is_intel { 83.333 } else { 1.0 };
+
+    let caps = MetalTimingCaps {
+        counter_set,
+        period_ns,
+    };
+    Some(Arc::new(caps))
+}
+
+// ── Error ──────────────────────────────────────────────────────────────
+
+/// Errors specific to the Metal backend.
+#[derive(thiserror::Error, Debug)]
+pub enum MetalBackendError {
+    /// A shader-argument error, forwarded transparently; converted from
+    /// [`ShaderArgsError`] via `From`.
+    #[error(transparent)]
+    ShaderArg(#[from] ShaderArgsError),
+    /// No Metal device is available on this system.
+    #[error("No Metal device available")]
+    NoDevice,
+    /// SPIR-V parsing failed; the payload is the message shown after the prefix.
+    #[error("Failed to parse SPIR-V: {0}")]
+    SpirVParse(String),
+    /// Naga validation failed; the payload is the message shown after the prefix.
+    #[error("Naga validation failed: {0}")]
+    NagaValidation(String),
+    /// Writing MSL failed; the payload is the message shown after the prefix.
+    #[error("Failed to write MSL: {0}")]
+    MslWrite(String),
+    /// Compiling the Metal library failed; the payload is the message shown
+    /// after the prefix.
+    #[error("Metal library compilation failed: {0}")]
+    LibraryCompile(String),
+    /// Creating the Metal pipeline failed; the payload is the message shown
+    /// after the prefix.
+    #[error("Metal pipeline creation failed: {0}")]
+    PipelineCreate(String),
+    /// The requested entry point (payload: its name) was not found in the module.
+    #[error("Entry point `{0}` not found in module")]
+    EntryPointNotFound(String),
+}
+
+// ── Buffer ─────────────────────────────────────────────────────────────
+
+/// A Metal device buffer with element count and usage metadata.
+pub struct MetalBuffer {
+    /// The underlying Metal buffer, exposed through [`MetalBuffer::raw`].
+    inner: MtlBuffer,
+    /// Element count (presumably in units of the buffer's element type —
+    /// TODO confirm; the type parameter is not visible here).
+    len: usize,
+    /// Total byte length, exposed through [`MetalBuffer::byte_len`].
+    byte_len: usize,
+    /// Usage flags recorded for this buffer.
+    usage: BufferUsages,
+    /// Zero-sized marker; stores no data.
+    _marker: PhantomData,
+}
+
+// SAFETY: MTLBuffer is documented thread-safe by Apple.
+unsafe impl Send for MetalBuffer {}
+unsafe impl Sync for MetalBuffer {}
+
+impl MetalBuffer {
+    /// Returns the underlying Metal buffer.
+    pub fn raw(&self) -> &MtlBuffer {
+        &self.inner
+    }
+
+    /// Returns the total byte length of this buffer.
+    pub fn byte_len(&self) -> usize {
+        self.byte_len
+    }
+}
+
+// ── Buffer slice ───────────────────────────────────────────────────────
+
+/// An immutable view into a Metal device buffer.
+#[derive(Clone, Copy)]
+pub struct MetalBufferSlice<'a> {
+    /// The Metal buffer this slice views into.
+    pub(crate) buffer: &'a MtlBuffer,
+    /// Byte offset of the slice within `buffer`.
+    pub(crate) byte_offset: u64,
+    /// Byte length of the slice.
+    pub(crate) byte_len: u64,
+}
+
+impl<'a> MetalBufferSlice<'a> {
+    /// Returns the underlying Metal buffer.
+    pub fn buffer(&self) -> &'a MtlBuffer {
+        self.buffer
+    }
+
+    /// Byte offset into the underlying buffer.
+    pub fn byte_offset(&self) -> u64 {
+        self.byte_offset
+    }
+
+    /// Byte length of this slice.
+    pub fn byte_len(&self) -> u64 {
+        self.byte_len
+    }
+}
+
+// ── Module / Function ──────────────────────────────────────────────────
+
+/// A loaded shader module: parsed naga IR ready to be translated to MSL
+/// per-entry-point at function load time.
+#[derive(Clone)]
+pub struct MetalModule {
+    /// Presumably the parsed naga module (the "naga IR" of the struct doc) —
+    /// TODO confirm; the concrete type is not visible here.
+    naga: Arc,
+    /// Presumably the naga validation info for `naga` — TODO confirm; the
+    /// concrete type is not visible here.
+    info: Arc,
+    /// Type layouts (size, alignment) for every type in `naga.types`.
+    /// Used to compute threadgroup memory allocations.
+    layouter: Arc,
+}
+
+/// A Metal compute pipeline plus the (group, binding) → MSL buffer slot map.
+#[derive(Clone)]
+pub struct MetalFunction {
+    /// The compiled compute pipeline state for this entry point.
+    pub(crate) pipeline: ComputePipelineState,
+    /// Sorted bindings (group, binding, descriptor_type) → MSL buffer slot.
+    // NOTE(review): the doc line above mentions `descriptor_type`, while the
+    // trailing comment below describes the tuple as (group, binding, slot) —
+    // confirm which one matches the element type.
+    pub(crate) slot_map: Arc>, // (group, binding, slot)
+    /// MSL buffer slot for push constants, if any.
+    #[cfg(feature = "push_constants")]
+    pub(crate) push_constant_slot: Option,
+    /// MSL buffer slot for the runtime-array sizes buffer.
+    ///
+    /// `None` means no runtime-array bindings are present in this entry
+    /// point and naga does not emit a `_mslBufferSizes` struct.
+    pub(crate) sizes_slot: Option,
+    /// Bindings that contribute a `uint sizeN;` field to `_mslBufferSizes`,
+    /// in declaration order. At dispatch time we look up the byte length of
+    /// each bound buffer and pack them into a `Vec` that is bound at
+    /// [`sizes_slot`](Self::sizes_slot) via `setBytes`.
+    pub(crate) sizes_bindings: Arc>,
+    /// Threadgroup memory sizes (in bytes) to allocate, indexed by the
+    /// implicit MSL threadgroup buffer index assigned by naga (declaration
+    /// order). At dispatch we call `setThreadgroupMemoryLength:atIndex:`
+    /// for each entry.
+    pub(crate) threadgroup_sizes: Arc>,
+    /// Workgroup size declared in the shader. Used for indirect dispatch.
+    pub(crate) workgroup_size: [u32; 3],
+}
+
+// SAFETY: ComputePipelineState wraps MTLComputePipelineState, which is thread-safe.
+unsafe impl Send for MetalFunction {} +unsafe impl Sync for MetalFunction {} + +// ── Encoder / Pass ───────────────────────────────────────────────────── + +/// Metal command encoder. Owns the in-flight `MTLCommandBuffer`. +pub struct MetalEncoder { + pub(crate) command_buffer: CommandBuffer, +} + +unsafe impl Send for MetalEncoder {} +unsafe impl Sync for MetalEncoder {} + +/// An active Metal compute pass. Holds the `MTLComputeCommandEncoder` +/// for the duration of the pass and ends encoding on drop. +pub struct MetalPass { + pub(crate) encoder: ComputeCommandEncoder, +} + +unsafe impl Send for MetalPass {} +unsafe impl Sync for MetalPass {} + +impl Drop for MetalPass { + fn drop(&mut self) { + self.encoder.end_encoding(); + } +} + +impl MetalPass { + /// Inserts a buffer-scope memory barrier into this compute pass. + /// + /// metal-rs 0.32 doesn't wrap `memoryBarrierWithScope:`, so the + /// selector is invoked directly via `msg_send!`. + pub fn memory_barrier(&mut self) { + const MTL_BARRIER_SCOPE_BUFFERS: u64 = 1; + unsafe { + let _: () = msg_send![self.encoder.as_ref(), + memoryBarrierWithScope: MTL_BARRIER_SCOPE_BUFFERS]; + } + } +} + +// ── Dispatch ─────────────────────────────────────────────────────────── + +/// Collects buffer arguments and launches a Metal compute kernel. +pub struct MetalDispatch<'a> { + pub(crate) encoder: &'a ComputeCommandEncoder, + pub(crate) function: &'a MetalFunction, + /// Collected (binding, buffer ref, byte_offset, byte_len) tuples. + pub(crate) args: Vec<(ShaderBinding, &'a MtlBuffer, u64, u64)>, + #[cfg(feature = "push_constants")] + pub(crate) push_constants: Vec, +} + +impl<'a> MetalDispatch<'a> { + /// Adds a buffer argument at the given binding location. + pub fn set_arg( + &mut self, + binding: ShaderBinding, + buffer: &'a MtlBuffer, + byte_offset: u64, + byte_len: u64, + ) { + self.args.push((binding, buffer, byte_offset, byte_len)); + } + + /// Sets push constant data for this dispatch. 
+ #[cfg(feature = "push_constants")] + pub fn set_push_constants(&mut self, data: &[u8]) { + self.push_constants.clear(); + self.push_constants.extend_from_slice(data); + } +} + +// ── Timestamps ───────────────────────────────────────────────────────── + +/// GPU timestamp manager backed by an `MTLCounterSampleBuffer` sampling at +/// stage boundaries (begin/end of each compute pass). +/// +/// Each `begin_pass` allocates a `(start_index, end_index)` pair and wires +/// up the active compute pass descriptor's sample buffer attachment so +/// Metal records the GPU timestamp at the start and end of the encoder. +pub struct MetalTimestamps { + sample_buffer: CounterSampleBuffer, + /// `capacity * 2` total slots (one begin + one end per pass). + capacity: u32, + /// Number of `(begin, end)` pairs allocated so far this frame. + next_index: u32, + /// Pass labels in allocation order; aligned with sample-pair indices. + labels: Vec, + /// Tick → nanosecond multiplier captured from the backend at creation. + period_ns: f64, +} + +// SAFETY: CounterSampleBuffer wraps an MTLCounterSampleBuffer, thread-safe. +unsafe impl Send for MetalTimestamps {} +unsafe impl Sync for MetalTimestamps {} + +impl MetalTimestamps { + /// Creates a new timestamp manager with room for `capacity` timed + /// passes. Returns `None` if the device doesn't expose stage-boundary + /// timestamp sampling, or if allocating the sample buffer fails. 
+ pub fn new(metal: &Metal, capacity: u32) -> Option { + if capacity == 0 { + return None; + } + let caps = metal.timing_caps.as_ref()?; + let descriptor = CounterSampleBufferDescriptor::new(); + descriptor.set_counter_set(&caps.counter_set); + descriptor.set_storage_mode(MTLStorageMode::Shared); + descriptor.set_sample_count((capacity as u64) * 2); + let sample_buffer = metal + .device + .new_counter_sample_buffer_with_descriptor(&descriptor) + .ok()?; + Some(MetalTimestamps { + sample_buffer, + capacity, + next_index: 0, + labels: Vec::with_capacity(capacity as usize), + period_ns: caps.period_ns, + }) + } + + /// Resets the manager for a new frame (drops all collected pairs). + pub fn reset(&mut self) { + self.next_index = 0; + self.labels.clear(); + } + + /// Reads back timestamp results after GPU synchronization. + /// + /// Must be called after the encoder containing the timed passes has + /// been submitted *and* the device has been synchronized — otherwise + /// the resolved values are unspecified. + pub fn read(&self) -> Result, MetalBackendError> { + if self.next_index == 0 { + return Ok(Vec::new()); + } + let count = self.next_index as u64 * 2; + // metal-rs 0.32's `resolve_counter_range` wrapper has a bug where + // it always passes length=0 to `getBytes:length:`, so we call + // `resolveCounterRange:` ourselves and copy the bytes out manually. + // Each `MTLCounterResultTimestamp` is a single u64 tick value. 
+ let mut raw = vec![0u64; count as usize]; + unsafe { + let range = NSRange { + location: 0, + length: count, + }; + let ns_data: *mut Object = + msg_send![self.sample_buffer.as_ref(), resolveCounterRange: range]; + if !ns_data.is_null() { + let total_bytes = count * std::mem::size_of::() as u64; + let _: () = msg_send![ns_data, + getBytes: raw.as_mut_ptr() as *mut std::ffi::c_void + length: total_bytes]; + } + } + let mut entries = Vec::with_capacity(self.labels.len()); + for (i, label) in self.labels.iter().enumerate() { + let begin = raw.get(i * 2).copied().unwrap_or(0); + let end = raw.get(i * 2 + 1).copied().unwrap_or(0); + let ticks = end.saturating_sub(begin) as f64; + entries.push(GpuTimestamp { + label: label.clone(), + duration_ms: ticks * self.period_ns / 1_000_000.0, + }); + } + Ok(entries) + } + + /// Allocates a `(begin, end)` index pair for a labeled pass. Returns + /// `None` if the sample buffer is full. + fn alloc_pair(&mut self, label: &str) -> Option<(NSUInteger, NSUInteger)> { + if self.next_index >= self.capacity { + return None; + } + let begin = (self.next_index * 2) as NSUInteger; + let end = begin + 1; + self.next_index += 1; + self.labels.push(label.to_string()); + Some((begin, end)) + } +} + +// ── Backend trait impl ───────────────────────────────────────────────── + +impl Backend for Metal { + const NAME: &'static str = "metal"; + const TARGET: CompileTarget = CompileTarget::Spirv; + + type Error = MetalBackendError; + type Buffer = MetalBuffer; + type BufferSlice<'b, T: DeviceValue> = MetalBufferSlice<'b>; + type Encoder = MetalEncoder; + type Pass = MetalPass; + type Timestamps = MetalTimestamps; + type Module = MetalModule; + type Function = MetalFunction; + type Dispatch<'a> = MetalDispatch<'a>; + + fn as_metal(&self) -> Option<&Metal> { + Some(self) + } + + /* + * Module / function loading. + */ + fn load_module_bytes(&self, bytes: &[u8]) -> Result { + // Module cache is keyed by content hash so the same SPIR-V isn't reparsed. 
+ let hash = fxhash(bytes); + { + let cache = self.module_cache.lock().unwrap(); + if let Some(module) = cache.get(&hash) { + return Ok(module.clone()); + } + } + + // Validate SPIR-V magic number. + if bytes.len() < 4 + || u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) != 0x07230203 + { + return Err(MetalBackendError::SpirVParse( + "Input does not start with SPIR-V magic number".into(), + )); + } + + // SPIR-V → naga IR. + let module = naga::front::spv::parse_u8_slice( + bytes, + &naga::front::spv::Options { + adjust_coordinate_space: false, + strict_capabilities: false, + block_ctx_dump_prefix: None, + }, + ) + .map_err(|e| MetalBackendError::SpirVParse(format!("{e}")))?; + + // Validate so the MSL backend has the type info it needs. + let info = naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::all(), + ) + .validate(&module) + .map_err(|e| MetalBackendError::NagaValidation(format!("{:?}", e.into_inner())))?; + + // Type layouts let us size threadgroup memory at dispatch time. + let mut layouter = naga::proc::Layouter::default(); + layouter + .update(module.to_ctx()) + .map_err(|e| MetalBackendError::NagaValidation(format!("layout: {e}")))?; + + let metal_module = MetalModule { + naga: Arc::new(module), + info: Arc::new(info), + layouter: Arc::new(layouter), + }; + + self.module_cache + .lock() + .unwrap() + .insert(hash, metal_module.clone()); + + Ok(metal_module) + } + + fn load_function( + &self, + module: &Self::Module, + entry_point: &str, + push_constant_size: u32, + ) -> Result { + // Without explicit layout info we infer bindings by scanning the naga module. 
+ let layouts = layouts_from_module(&module.naga); + self.load_function_with_layouts(module, entry_point, push_constant_size, &layouts) + } + + fn load_function_with_layouts( + &self, + module: &Self::Module, + entry_point: &str, + push_constant_size: u32, + layouts: &BindGroupLayoutInfo, + ) -> Result { + // Collect all bindings, sort by (group, binding), assign MSL slots 0..N-1. + let mut all_bindings: Vec = layouts + .groups + .iter() + .flat_map(|g| g.iter().copied()) + .collect(); + all_bindings.sort_by_key(|b| (b.space, b.index)); + + let mut resources: BTreeMap = + BTreeMap::new(); + let mut slot_map: Vec<(u32, u32, u64)> = Vec::with_capacity(all_bindings.len()); + for (slot, binding) in all_bindings.iter().enumerate() { + let mutable = matches!( + binding.descriptor_type, + DescriptorType::Storage { read_only: false } + ); + let target = naga::back::msl::BindTarget { + buffer: Some(slot as u8), + texture: None, + sampler: None, + external_texture: None, + mutable, + }; + resources.insert( + naga::ResourceBinding { + group: binding.space, + binding: binding.index, + }, + target, + ); + slot_map.push((binding.space, binding.index, slot as u64)); + } + + let next_slot = all_bindings.len() as u8; + #[cfg(feature = "push_constants")] + let push_constant_slot: Option = if push_constant_size > 0 { + Some(next_slot) + } else { + None + }; + #[cfg(not(feature = "push_constants"))] + let push_constant_slot: Option = None; + let _ = push_constant_size; + + let next_slot = next_slot + push_constant_slot.is_some() as u8; + + // Collect the bindings that need a `uint sizeN;` entry in + // `_mslBufferSizes`, in the same handle iteration order naga uses. + // Each runtime-sized storage buffer contributes one entry. + let mut sizes_bindings: Vec<(u32, u32)> = Vec::new(); + // Threadgroup-memory globals: collect their byte sizes in handle + // iteration order. 
naga emits them as kernel parameters with no
+        // `[[threadgroup(N)]]` attribute, so MSL assigns them indices
+        // 0, 1, 2... in declaration order — matching this Vec's indices.
+        let mut threadgroup_sizes: Vec = Vec::new();
+        // Only globals referenced by the entry point being translated become
+        // kernel parameters. Workgroup variables that belong to *other* entry
+        // points of the same module must therefore be skipped: counting them
+        // would shift every later threadgroup index and mis-size the memory.
+        let ep_index = module
+            .naga
+            .entry_points
+            .iter()
+            .position(|ep| ep.stage == naga::ShaderStage::Compute && ep.name == entry_point)
+            .ok_or_else(|| MetalBackendError::EntryPointNotFound(entry_point.into()))?;
+        let ep_info = module.info.get_entry_point(ep_index);
+        for (handle, var) in module.naga.global_variables.iter() {
+            // Runtime-sized globals are intentionally *not* filtered by usage:
+            // the collection order over all globals is kept as-is.
+            if needs_array_length(var.ty, &module.naga.types) {
+                if let Some(b) = &var.binding {
+                    sizes_bindings.push((b.group, b.binding));
+                }
+            }
+            let used_by_entry_point = !ep_info[handle].is_empty();
+            if used_by_entry_point && matches!(var.space, naga::AddressSpace::WorkGroup) {
+                let layout = module.layouter[var.ty];
+                threadgroup_sizes.push(layout.size);
+            }
+        }
+        // The sizes buffer takes the first slot after the regular bindings
+        // (and the push-constant slot, when present).
+        let sizes_slot = if sizes_bindings.is_empty() {
+            None
+        } else {
+            Some(next_slot)
+        };
+
+        let entry_point_resources = naga::back::msl::EntryPointResources {
+            resources,
+            immediates_buffer: push_constant_slot,
+            sizes_buffer: sizes_slot,
+        };
+
+        let mut per_entry_point: BTreeMap =
+            BTreeMap::new();
+        per_entry_point.insert(entry_point.to_string(), entry_point_resources);
+
+        let options = naga::back::msl::Options {
+            lang_version: (2, 4),
+            per_entry_point_map: per_entry_point,
+            inline_samplers: vec![],
+            spirv_cross_compatibility: false,
+            fake_missing_bindings: false,
+            bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
+            zero_initialize_workgroup_memory: false,
+            force_loop_bounding: false,
+        };
+
+        let pipeline_options = naga::back::msl::PipelineOptions {
+            entry_point: Some((naga::ShaderStage::Compute, entry_point.to_string())),
+            allow_and_force_point_size: false,
+            vertex_pulling_transform: false,
+            vertex_buffer_mappings: vec![],
+        };
+
+        let mut msl = String::new();
+        let mut writer = naga::back::msl::Writer::new(&mut msl);
+        let translation_info = writer
+            .write(&module.naga, &module.info, &options, &pipeline_options)
+            .map_err(|e| MetalBackendError::MslWrite(format!("{e}")))?;
+
+        // Debugging aid: set `KHAL_METAL_DUMP_MSL` to print the generated MSL.
+        if std::env::var("KHAL_METAL_DUMP_MSL").is_ok() {
+            eprintln!("──── MSL for `{}` ────\n{}\n────────────", entry_point, msl);
+        }
+
+        // Find the mangled MSL entry point name corresponding to our
requested entry point. + let mangled = translation_info + .entry_point_names + .iter() + .zip(module.naga.entry_points.iter()) + .find_map(|(name_result, ep)| { + if ep.name == entry_point { + name_result.as_ref().ok().map(|n| n.clone()) + } else { + None + } + }) + .ok_or_else(|| MetalBackendError::EntryPointNotFound(entry_point.into()))?; + + // Workgroup size lives in the entry point. + let workgroup_size = module + .naga + .entry_points + .iter() + .find(|ep| ep.name == entry_point) + .map(|ep| ep.workgroup_size) + .ok_or_else(|| MetalBackendError::EntryPointNotFound(entry_point.into()))?; + + // Compile MSL. + let compile_options = metal::CompileOptions::new(); + let library: Library = self + .device + .new_library_with_source(&msl, &compile_options) + .map_err(MetalBackendError::LibraryCompile)?; + + let function = library + .get_function(&mangled, None) + .map_err(|e| MetalBackendError::EntryPointNotFound(format!("{entry_point}: {e}")))?; + + let descriptor = ComputePipelineDescriptor::new(); + descriptor.set_compute_function(Some(&function)); + descriptor.set_label(entry_point); + + let pipeline = self + .device + .new_compute_pipeline_state(&descriptor) + .map_err(MetalBackendError::PipelineCreate)?; + + Ok(MetalFunction { + pipeline, + slot_map: Arc::new(slot_map), + #[cfg(feature = "push_constants")] + push_constant_slot: push_constant_slot.map(|s| s as u64), + sizes_slot: sizes_slot.map(|s| s as u64), + sizes_bindings: Arc::new(sizes_bindings), + threadgroup_sizes: Arc::new(threadgroup_sizes), + workgroup_size, + }) + } + + /* + * Kernel dispatch. 
+ */ + fn begin_encoding(&self) -> Self::Encoder { + let cmd_buf = self.queue.new_command_buffer().to_owned(); + MetalEncoder { + command_buffer: cmd_buf, + } + } + + fn begin_dispatch<'a>( + &'a self, + pass: &'a mut Self::Pass, + function: &'a Self::Function, + ) -> Self::Dispatch<'a> { + pass.encoder.set_compute_pipeline_state(&function.pipeline); + MetalDispatch { + encoder: &pass.encoder, + function, + args: Vec::new(), + #[cfg(feature = "push_constants")] + push_constants: Vec::new(), + } + } + + fn synchronize(&self) -> Result<(), Self::Error> { + // Submit and wait on a fresh empty command buffer to flush the queue. + let cb = self.queue.new_command_buffer(); + cb.commit(); + cb.wait_until_completed(); + Ok(()) + } + + fn submit(&self, encoder: Self::Encoder) -> Result<(), Self::Error> { + encoder.command_buffer.commit(); + Ok(()) + } + + /* + * Buffer handling. + */ + fn init_buffer( + &self, + data: &[T], + usage: BufferUsages, + ) -> Result, Self::Error> { + let bytes: &[u8] = bytemuck::cast_slice(data); + let len = data.len(); + let byte_len = bytes.len(); + let inner = if byte_len == 0 { + // MTLBuffer of length 0 isn't allowed; allocate a single byte placeholder. 
+            self.device.new_buffer(1, resource_options(usage))
+        } else {
+            self.device.new_buffer_with_data(
+                bytes.as_ptr() as _,
+                byte_len as NSUInteger,
+                resource_options(usage),
+            )
+        };
+        Ok(MetalBuffer {
+            inner,
+            len,
+            byte_len,
+            usage,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Allocates a buffer sized for `len` elements without uploading any data.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the byte length (`len` × element size) overflows `usize`.
+    fn uninit_buffer(
+        &self,
+        len: usize,
+        usage: BufferUsages,
+    ) -> Result, Self::Error> {
+        let elt_size = std::mem::size_of::();
+        // A wrapped product would silently allocate far less than requested.
+        let byte_len = len
+            .checked_mul(elt_size)
+            .expect("uninit_buffer: byte length overflows usize");
+        // MTLBuffer of length 0 isn't allowed: allocate a one-byte placeholder
+        // but keep reporting the logical length, exactly like `init_buffer`.
+        let inner = self
+            .device
+            .new_buffer(byte_len.max(1) as NSUInteger, resource_options(usage));
+        Ok(MetalBuffer {
+            inner,
+            len,
+            byte_len,
+            usage,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Copies `data` into `buffer`, starting `offset` elements from its start.
+    ///
+    /// Writing an empty slice is a no-op.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the destination range does not fit inside the buffer; the
+    /// copy below is a raw `memcpy`, so an unchecked range would corrupt
+    /// memory past the allocation.
+    fn write_buffer(
+        &self,
+        buffer: &mut Self::Buffer,
+        offset: u64,
+        data: &[T],
+    ) -> Result<(), Self::Error> {
+        let bytes: &[u8] = bytemuck::cast_slice(data);
+        if bytes.is_empty() {
+            return Ok(());
+        }
+        let elt_size = std::mem::size_of::();
+        // Overflow-checked `offset * elt_size`, then make sure the whole
+        // destination range lies within the buffer's byte length.
+        let byte_offset = usize::try_from(offset)
+            .ok()
+            .and_then(|elements| elements.checked_mul(elt_size))
+            .filter(|&start| {
+                start
+                    .checked_add(bytes.len())
+                    .is_some_and(|end| end <= buffer.byte_len)
+            })
+            .unwrap_or_else(|| {
+                panic!(
+                    "write_buffer: {} bytes at element offset {offset} exceed the buffer's {} bytes",
+                    bytes.len(),
+                    buffer.byte_len,
+                )
+            });
+        // SAFETY: contents() is valid for the buffer's lifetime, the range
+        // `byte_offset..byte_offset + bytes.len()` was checked against the
+        // buffer's byte length above, and `bytes` is host memory that cannot
+        // overlap the destination we hold exclusively through `&mut`.
+        unsafe {
+            let dst = (buffer.inner.contents() as *mut u8).add(byte_offset);
+            std::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
+        }
+        // The modified-range notification only applies to managed storage;
+        // shared-storage buffers are coherent and must not receive it.
+        if buffer.inner.storage_mode() == metal::MTLStorageMode::Managed {
+            let range = metal::NSRange {
+                location: byte_offset as NSUInteger,
+                length: bytes.len() as NSUInteger,
+            };
+            buffer.inner.did_modify_range(range);
+        }
+        Ok(())
+    }
+
+    async fn read_buffer(
+        &self,
+        buffer: &Self::Buffer,
+        out: &mut [T],
+    ) -> Result<(), Self::Error> {
+        // Make sure all prior submissions are visible.
+        self.synchronize()?;
+        if buffer.byte_len == 0 {
+            return Ok(());
+        }
+        // SAFETY: contents() is valid for the buffer's lifetime; we copy non-overlapping bytes.
+ let out_bytes = std::mem::size_of_val(out); + let copy_len = buffer.byte_len.min(out_bytes); + unsafe { + std::ptr::copy_nonoverlapping( + buffer.inner.contents() as *const u8, + out.as_mut_ptr() as *mut u8, + copy_len, + ); + } + Ok(()) + } + + async fn slow_read_buffer( + &self, + buffer: &Self::Buffer, + out: &mut [T], + ) -> Result<(), Self::Error> { + // For shared-storage MTLBuffers, host pointer access is direct after sync, + // so this is identical to `read_buffer`. We blit through a staging buffer + // when the source isn't host-coherent (private storage). + if buffer.inner.storage_mode() == metal::MTLStorageMode::Private { + // Allocate shared staging, blit, sync, copy out. + let staging = self.uninit_buffer::( + buffer.byte_len, + BufferUsages::COPY_DST | BufferUsages::MAP_READ, + )?; + let encoder = self.begin_encoding(); + let blit = encoder.command_buffer.new_blit_command_encoder(); + blit.copy_from_buffer( + &buffer.inner, + 0, + &staging.inner, + 0, + buffer.byte_len as NSUInteger, + ); + blit.end_encoding(); + self.submit(encoder)?; + self.synchronize()?; + + let out_bytes = std::mem::size_of_val(out); + let copy_len = staging.byte_len.min(out_bytes); + // SAFETY: contents() valid for staging's lifetime; copy non-overlapping. + unsafe { + std::ptr::copy_nonoverlapping( + staging.inner.contents() as *const u8, + out.as_mut_ptr() as *mut u8, + copy_len, + ); + } + return Ok(()); + } + self.read_buffer(buffer, out).await + } +} + +// ── Encoder ──────────────────────────────────────────────────────────── + +impl Encoder for MetalEncoder { + fn begin_pass(&mut self, label: &str, timestamps: Option<&mut MetalTimestamps>) -> MetalPass { + // If timestamps are requested and we still have room in the sample + // buffer, configure the compute pass descriptor's + // sample-buffer-attachment so Metal records GPU timestamps at the + // start/end of this encoder. 
+ let encoder = match timestamps.and_then(|ts| { + let (begin, end) = ts.alloc_pair(label)?; + Some((ts, begin, end)) + }) { + Some((ts, begin, end)) => { + let descriptor = ComputePassDescriptor::new(); + descriptor.set_dispatch_type(MTLDispatchType::Serial); + let attachment = descriptor + .sample_buffer_attachments() + .object_at(0) + .expect("compute pass sample buffer attachment 0"); + attachment.set_sample_buffer(&ts.sample_buffer); + attachment.set_start_of_encoder_sample_index(begin); + attachment.set_end_of_encoder_sample_index(end); + self.command_buffer + .compute_command_encoder_with_descriptor(descriptor) + .to_owned() + } + None => self + .command_buffer + .compute_command_encoder_with_dispatch_type(MTLDispatchType::Serial) + .to_owned(), + }; + if !label.is_empty() { + encoder.set_label(label); + } + MetalPass { encoder } + } + + fn copy_buffer_to_buffer( + &mut self, + source: &::Buffer, + source_offset: usize, + target: &mut ::Buffer, + target_offset: usize, + copy_len: usize, + ) -> Result<(), MetalBackendError> { + if copy_len == 0 { + return Ok(()); + } + let elt_size = std::mem::size_of::(); + let blit = self.command_buffer.new_blit_command_encoder(); + blit.copy_from_buffer( + &source.inner, + (source_offset * elt_size) as NSUInteger, + &target.inner, + (target_offset * elt_size) as NSUInteger, + (copy_len * elt_size) as NSUInteger, + ); + blit.end_encoding(); + Ok(()) + } + + fn memory_barrier(&mut self, pass: &mut MetalPass) { + pass.memory_barrier(); + } +} + +// ── Dispatch ─────────────────────────────────────────────────────────── + +impl<'a> Dispatch<'a, Metal> for MetalDispatch<'a> { + #[cfg(feature = "push_constants")] + fn set_push_constants(&mut self, data: &[u8]) { + self.push_constants.clear(); + self.push_constants.extend_from_slice(data); + } + + fn launch<'b>( + self, + grid: impl Into>, + block_dim: [u32; 3], + ) -> Result<(), MetalBackendError> { + // Bind buffers per the (group, binding) → MSL slot map. 
+ for (binding, buffer, byte_offset, _byte_len) in &self.args { + let slot = self + .function + .slot_map + .iter() + .find(|(g, b, _)| *g == binding.space && *b == binding.index) + .map(|(_, _, s)| *s); + if let Some(slot) = slot { + self.encoder + .set_buffer(slot as NSUInteger, Some(buffer), *byte_offset); + } + } + + // Push constants: bind inline via setBytes at the reserved slot. + #[cfg(feature = "push_constants")] + if let Some(slot) = self.function.push_constant_slot { + if !self.push_constants.is_empty() { + self.encoder.set_bytes( + slot as NSUInteger, + self.push_constants.len() as NSUInteger, + self.push_constants.as_ptr() as *const _, + ); + } + } + + // Threadgroup memory: naga emits WorkGroup-space globals as kernel + // parameters without explicit `[[threadgroup(N)]]` attributes, so + // MSL assigns implicit indices in declaration order. The host must + // size each slot via `setThreadgroupMemoryLength:atIndex:`. + for (idx, &size) in self.function.threadgroup_sizes.iter().enumerate() { + // Metal requires non-zero, 16-byte-aligned threadgroup sizes. + let aligned = ((size as NSUInteger) + 15) & !15; + let aligned = aligned.max(16); + self.encoder + .set_threadgroup_memory_length(idx as NSUInteger, aligned); + } + + // Runtime-array sizes buffer: naga emits bounds checks against + // `_mslBufferSizes.sizeN` for each storage buffer with a runtime + // array. We pack the byte length of each such buffer (in the order + // naga declared them) and bind via setBytes at `sizes_slot`. 
+ if let Some(slot) = self.function.sizes_slot { + let mut sizes: smallvec::SmallVec<[u32; 8]> = + smallvec::SmallVec::with_capacity(self.function.sizes_bindings.len()); + for (group, binding) in self.function.sizes_bindings.iter() { + let entry = self + .args + .iter() + .find(|(b, _, _, _)| b.space == *group && b.index == *binding); + let byte_len = entry + .map(|(_, _, _, byte_len)| *byte_len as u32) + .unwrap_or(0); + sizes.push(byte_len); + } + if !sizes.is_empty() { + self.encoder.set_bytes( + slot as NSUInteger, + (sizes.len() * std::mem::size_of::()) as NSUInteger, + sizes.as_ptr() as *const _, + ); + } + } + + // Resolve grid dimensions. + let (grid_size, threads_per_threadgroup) = match grid.into() { + DispatchGrid::Grid(g) => ( + MTLSize { + width: g[0] as NSUInteger, + height: g[1] as NSUInteger, + depth: g[2] as NSUInteger, + }, + MTLSize { + width: block_dim[0] as NSUInteger, + height: block_dim[1] as NSUInteger, + depth: block_dim[2] as NSUInteger, + }, + ), + DispatchGrid::ThreadCount(t) => ( + MTLSize { + width: t[0].div_ceil(block_dim[0]) as NSUInteger, + height: t[1].div_ceil(block_dim[1]) as NSUInteger, + depth: t[2].div_ceil(block_dim[2]) as NSUInteger, + }, + MTLSize { + width: block_dim[0] as NSUInteger, + height: block_dim[1] as NSUInteger, + depth: block_dim[2] as NSUInteger, + }, + ), + DispatchGrid::Indirect(buffer) => { + self.encoder.dispatch_thread_groups_indirect( + &buffer.inner, + 0, + MTLSize { + width: block_dim[0] as NSUInteger, + height: block_dim[1] as NSUInteger, + depth: block_dim[2] as NSUInteger, + }, + ); + return Ok(()); + } + }; + + if grid_size.width == 0 || grid_size.height == 0 || grid_size.depth == 0 { + return Ok(()); + } + + self.encoder + .dispatch_thread_groups(grid_size, threads_per_threadgroup); + + Ok(()) + } +} + +// ── Buffer trait impl ────────────────────────────────────────────────── + +impl crate::backend::Buffer for MetalBuffer { + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn len(&self) -> 
usize
+    where
+        T: Sized,
+    {
+        self.len
+    }
+
+    /// Returns a view over the elements selected by `range`.
+    ///
+    /// Bounds are given in elements and converted to bytes. The resulting
+    /// byte range is clamped to the buffer, so an inverted or out-of-range
+    /// `range` yields an empty (or shortened) slice instead of underflowing.
+    fn slice(&self, range: impl RangeBounds) -> MetalBufferSlice<'_> {
+        let elt_size = std::mem::size_of::() as u64;
+        let total = self.byte_len as u64;
+        let start = match range.start_bound() {
+            std::ops::Bound::Included(&n) => (n as u64).saturating_mul(elt_size),
+            std::ops::Bound::Excluded(&n) => {
+                (n as u64).saturating_add(1).saturating_mul(elt_size)
+            }
+            std::ops::Bound::Unbounded => 0,
+        };
+        let end = match range.end_bound() {
+            std::ops::Bound::Included(&n) => {
+                (n as u64).saturating_add(1).saturating_mul(elt_size)
+            }
+            std::ops::Bound::Excluded(&n) => (n as u64).saturating_mul(elt_size),
+            std::ops::Bound::Unbounded => total,
+        };
+        // Never expose bytes past the buffer, and keep `start <= end` so the
+        // length subtraction below cannot wrap.
+        let end = end.min(total);
+        let start = start.min(end);
+        MetalBufferSlice {
+            buffer: &self.inner,
+            byte_offset: start,
+            byte_len: end - start,
+        }
+    }
+
+    fn usage(&self) -> BufferUsages {
+        self.usage
+    }
+}
+
+// ── Helpers ────────────────────────────────────────────────────────────
+
+/// Maps khal usage flags to Metal resource options.
+///
+/// Every buffer currently uses shared storage (CPU/GPU coherent), whatever
+/// its usage flags. Private storage would require staging for any host
+/// upload; staying on shared keeps `init_buffer` / `write_buffer`
+/// straightforward and matches the semantics of wgpu's `queue.write_buffer`.
+fn resource_options(_usage: BufferUsages) -> MTLResourceOptions {
+    MTLResourceOptions::StorageModeShared
+}
+
+/// Returns true if `ty` (or its trailing struct member) is a runtime-sized
+/// array. naga emits a `uint sizeN;` field in `_mslBufferSizes` for each
+/// global variable for which this returns true.
+fn needs_array_length(ty: naga::Handle, types: &naga::UniqueArena) -> bool {
+    match types[ty].inner {
+        naga::TypeInner::Struct { ref members, .. } => {
+            if let Some(member) = members.last() {
+                if let naga::TypeInner::Array {
+                    size: naga::ArraySize::Dynamic,
+                    ..
+ } = types[member.ty].inner + { + return true; + } + } + false + } + naga::TypeInner::Array { + size: naga::ArraySize::Dynamic, + .. + } => true, + _ => false, + } +} + +/// Best-effort recovery of bind group layout info from a naga module. +/// Used as a fallback when [`Backend::load_function`] is called without +/// explicit layout info. +fn layouts_from_module(module: &naga::Module) -> BindGroupLayoutInfo { + let mut groups: Vec> = Vec::new(); + for (_, var) in module.global_variables.iter() { + let Some(binding) = &var.binding else { + continue; + }; + let descriptor_type = match var.space { + naga::AddressSpace::Uniform => DescriptorType::Uniform, + naga::AddressSpace::Storage { access } => DescriptorType::Storage { + read_only: !access.contains(naga::StorageAccess::STORE), + }, + _ => continue, + }; + let group = binding.group as usize; + if groups.len() <= group { + groups.resize_with(group + 1, Vec::new); + } + groups[group].push(ShaderBinding { + space: binding.group, + index: binding.binding, + descriptor_type, + }); + } + BindGroupLayoutInfo { groups } +} + +/// Simple FNV-1a hash for module cache keys. +fn fxhash(data: &[u8]) -> u64 { + let mut hash: u64 = 0xcbf29ce484222325; + for &b in data { + hash ^= b as u64; + hash = hash.wrapping_mul(0x100000001b3); + } + hash +} diff --git a/crates/khal/src/backend/mod.rs b/crates/khal/src/backend/mod.rs index c8cfccc..cee795f 100644 --- a/crates/khal/src/backend/mod.rs +++ b/crates/khal/src/backend/mod.rs @@ -24,6 +24,11 @@ pub use cuda::Cuda; #[cfg(feature = "cuda")] pub mod cuda; +#[cfg(feature = "metal")] +pub use metal::Metal; +#[cfg(feature = "metal")] +pub mod metal; + mod any_backend; pub use any_backend::*; @@ -180,6 +185,12 @@ pub trait Backend: 'static + Sized + MaybeSendSync { None } + /// Downcasts to the Metal backend, if applicable. + #[cfg(feature = "metal")] + fn as_metal(&self) -> Option<&Metal> { + None + } + /* * Module/function loading. 
*/ @@ -303,6 +314,17 @@ pub trait Encoder { target_offset: usize, copy_len: usize, ) -> Result<(), B::Error>; + + /// Inserts a buffer-scope memory barrier into the active compute pass. + /// + /// Prefer the inherent [`GpuPass::memory_barrier`] method when working + /// with the type-erased backend; this trait method exists for backends + /// that implement [`Encoder`] directly. Backends that auto-insert + /// barriers between dispatches (e.g. WebGPU/wgpu) implement this as a + /// no-op. Backends that do *not* (e.g. Metal with + /// `MTLDispatchType::Concurrent`) emit a real barrier so subsequent + /// dispatches in the same pass see writes from previous dispatches. + fn memory_barrier(&mut self, _pass: &mut B::Pass) {} } /// An in-progress compute dispatch that collects bindings and launches kernels. diff --git a/crates/khal/src/backend/webgpu.rs b/crates/khal/src/backend/webgpu.rs index 112a891..f074eb1 100644 --- a/crates/khal/src/backend/webgpu.rs +++ b/crates/khal/src/backend/webgpu.rs @@ -22,6 +22,17 @@ use wgpu::{ PollError, Queue, ShaderModule, ShaderRuntimeChecks, ShaderStages, }; +/// Runtime checks used for `create_shader_module_trusted`. +/// +/// NOTE: we keep force_loop_bounding on to avoid what appears to be miscompilation of the +/// multibody kernels on some platforms (Windows native + Nvidia gpu). +fn shader_runtime_checks() -> ShaderRuntimeChecks { + ShaderRuntimeChecks { + force_loop_bounding: true, + ..ShaderRuntimeChecks::unchecked() + } +} + /// A WebGPU buffer slice that tracks its byte length for safe reinterpretation.
#[derive(Clone, Copy)] pub struct WebGpuBufferSlice<'a> { @@ -221,7 +232,7 @@ impl WebGpu { label: None, source, }, - ShaderRuntimeChecks::unchecked(), + shader_runtime_checks(), ) }; Ok(shader_module) @@ -322,7 +333,7 @@ impl Backend for WebGpu { label: None, source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(&data)), }, - ShaderRuntimeChecks::unchecked(), + shader_runtime_checks(), ) }; diff --git a/crates/khal/src/function.rs b/crates/khal/src/function.rs index bed0b26..9530c5c 100644 --- a/crates/khal/src/function.rs +++ b/crates/khal/src/function.rs @@ -58,6 +58,8 @@ impl GpuFunction { InnerGpuFunction::WebGpu(f) => GpuFunction::WebGpu(f, PhantomData), #[cfg(feature = "cuda")] InnerGpuFunction::Cuda(f) => GpuFunction::Cuda(f, PhantomData), + #[cfg(feature = "metal")] + InnerGpuFunction::Metal(f) => GpuFunction::Metal(f, PhantomData), InnerGpuFunction::Noop => GpuFunction::Noop(PhantomData), } } @@ -69,6 +71,8 @@ impl GpuFunction { GpuFunction::WebGpu(f, _) => InnerGpuFunction::WebGpu(f.clone()), #[cfg(feature = "cuda")] GpuFunction::Cuda(f, _) => InnerGpuFunction::Cuda(f.clone()), + #[cfg(feature = "metal")] + GpuFunction::Metal(f, _) => InnerGpuFunction::Metal(f.clone()), GpuFunction::Noop(_) => InnerGpuFunction::Noop, } } diff --git a/crates/khal/tests/metal_timestamps.rs b/crates/khal/tests/metal_timestamps.rs new file mode 100644 index 0000000..1a1b77e --- /dev/null +++ b/crates/khal/tests/metal_timestamps.rs @@ -0,0 +1,67 @@ +//! Smoke test: the Metal backend's timestamp pipeline end-to-end. +//! +//! Confirms that `MetalTimestamps::new` succeeds on this machine, that a +//! timed compute pass produces a non-zero duration, and that the labels +//! line up with what the host enqueued. 
+ +#![cfg(feature = "metal")] + +use khal::backend::Metal; +use khal::backend::metal::MetalTimestamps; +use khal::backend::{Backend, Buffer, BufferUsages, Encoder}; + +#[test] +fn metal_timestamps_smoke() { + let metal = match Metal::new() { + Ok(m) => m, + Err(e) => { + eprintln!("Skipping: no Metal device available ({e})"); + return; + } + }; + if !metal.timestamp_supported() { + eprintln!("Skipping: device does not expose timestamp counter set"); + return; + } + + let mut timestamps = MetalTimestamps::new(&metal, 4) + .expect("MetalTimestamps::new returned None despite timestamp_supported()"); + + // A trivial GPU operation: a buffer-to-buffer copy big enough that the + // pass surrounding it has measurable duration. We don't actually need + // a compute dispatch here — just an encoder pass with samples on it. + // Use a couple of distinct passes to exercise multiple sample pairs. + let src = metal + .init_buffer::( + &vec![1.0_f32; 1024 * 1024], + BufferUsages::STORAGE | BufferUsages::COPY_SRC, + ) + .unwrap(); + let mut dst = metal + .uninit_buffer::(1024 * 1024, BufferUsages::STORAGE | BufferUsages::COPY_DST) + .unwrap(); + + let mut encoder = metal.begin_encoding(); + { + let _pass = encoder.begin_pass("first_pass", Some(&mut timestamps)); + } + encoder + .copy_buffer_to_buffer(&src, 0, &mut dst, 0, src.len()) + .unwrap(); + { + let _pass = encoder.begin_pass("second_pass", Some(&mut timestamps)); + } + metal.submit(encoder).unwrap(); + metal.synchronize().unwrap(); + + let entries = timestamps.read().expect("read timestamps"); + assert_eq!(entries.len(), 2, "expected two timed passes"); + assert_eq!(entries[0].label, "first_pass"); + assert_eq!(entries[1].label, "second_pass"); + // Timestamps may be ~zero for an empty pass, but they must be finite + // and non-negative. + for e in &entries { + assert!(e.duration_ms.is_finite(), "{} duration not finite", e.label); + assert!(e.duration_ms >= 0.0, "{} duration negative", e.label); + } +}