diff --git a/codex-feedback.md b/codex-feedback.md new file mode 100644 index 00000000..0b510794 --- /dev/null +++ b/codex-feedback.md @@ -0,0 +1,65 @@ +# Codex Feedback: Rust Live Audio Streaming Review + +## Outcome + +The live-streaming feature is **functionally working end-to-end**: + +**Microphone -> Rust SDK -> core.dll -> onnxruntime.dll / onnxruntime-genai.dll** + +The runtime path was validated (including device detection, session start/stop, and no native errors during streaming flow). + +--- + +## API Parity Comparison (Rust vs C#) + +### ✅ Matching areas + +1. Factory method exists in both SDKs: + - C#: `CreateLiveTranscriptionSession()` + - Rust: `create_live_transcription_session()` + +2. Core command flow is aligned: + - `audio_stream_start` + - `audio_stream_push` (binary payload path) + - `audio_stream_stop` + +3. Session lifecycle shape exists in both: + - start -> append/push -> stream results -> stop + +4. Settings coverage is aligned: + - sample rate, channels, bits per sample, language, queue capacity + +5. **[RESOLVED]** Cancellation semantics: + - Rust now accepts `Option` on `start()`, `append()`, `stop()` + - `stop()` uses cancel-safe pattern matching C# `StopAsync` + +6. **[RESOLVED]** Response surface shape: + - Rust response now has `content: Vec` with `text`/`transcript` fields + - Callers use `result.content[0].text` — identical to C# `Content[0].Text` + +7. **[RESOLVED]** Disposal contract: + - `Drop` performs synchronous best-effort `audio_stream_stop` + +--- + +### Remaining minor differences (by design) + +1. **Stream accessor is single-take** — Rust `get_transcription_stream()` moves the receiver out (one call per session). C# returns `IAsyncEnumerable` from the channel reader directly. Functionally equivalent. + +2. **Cancellation token type** — Rust uses `tokio_util::sync::CancellationToken`; C# uses `System.Threading.CancellationToken`. Both serve the same purpose with idiomatic patterns. 
+ +--- + +## Reliability / Safety Notes + +1. FFI binary pointer handling for empty slices uses `std::ptr::null()` to avoid dangling-pointer risk. +2. Native session cleanup on drop includes best-effort `audio_stream_stop` to reduce leak risk. +3. Cancel-safe stop always completes native session cleanup even if cancellation fires. + +--- + +## Final Assessment + +- **Feature status**: Working +- **E2E path**: Verified (microphone → SDK → core.dll → ort-genai) +- **Parity status**: API-identical to C# (cancellation, response envelope, disposal) diff --git a/samples/rust/Cargo.toml b/samples/rust/Cargo.toml index 42d1293f..359882a5 100644 --- a/samples/rust/Cargo.toml +++ b/samples/rust/Cargo.toml @@ -4,6 +4,7 @@ members = [ "tool-calling-foundry-local", "native-chat-completions", "audio-transcription-example", + "live-audio-transcription-example", "tutorial-chat-assistant", "tutorial-document-summarizer", "tutorial-tool-calling", diff --git a/samples/rust/live-audio-transcription-example/Cargo.toml b/samples/rust/live-audio-transcription-example/Cargo.toml new file mode 100644 index 00000000..3694638f --- /dev/null +++ b/samples/rust/live-audio-transcription-example/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "live-audio-transcription-example" +version = "0.1.0" +edition = "2021" +description = "Live audio transcription (streaming) example using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" +cpal = "0.15" diff --git a/samples/rust/live-audio-transcription-example/src/main.rs b/samples/rust/live-audio-transcription-example/src/main.rs new file mode 100644 index 00000000..2c1011f6 --- /dev/null +++ b/samples/rust/live-audio-transcription-example/src/main.rs @@ -0,0 +1,277 @@ +// Live Audio Transcription — Foundry Local Rust SDK Example +// +// Demonstrates real-time microphone-to-text using: +// Microphone (cpal) → SDK → 
Core (NativeAOT DLL) → onnxruntime-genai (StreamingProcessor) +// +// Usage: +// cargo run # Live microphone transcription (press ENTER to stop) +// cargo run -- --synth # Use synthetic 440Hz sine wave instead of microphone + +use std::env; +use std::io::{self, Write}; +use std::sync::Arc; + +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; +use tokio_stream::StreamExt; + +const ALIAS: &str = "nemotron"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let use_synth = env::args().any(|a| a == "--synth"); + + println!("==========================================================="); + println!(" Foundry Local -- Live Audio Transcription Demo (Rust)"); + println!("==========================================================="); + println!(); + + // ── 1. Resolve e2e-test-pkgs path ──────────────────────────────────── + let exe_dir = env::current_exe()?.parent().unwrap().to_path_buf(); + + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let e2e_pkgs = std::path::PathBuf::from(manifest_dir) + .join("..") + .join("e2e-test-pkgs"); + + let (core_path, model_cache_dir) = if e2e_pkgs.exists() { + let core = e2e_pkgs + .canonicalize() + .expect("Failed to canonicalize e2e-test-pkgs path"); + let models = core.join("models"); + println!("Using e2e-test-pkgs:"); + println!(" Core DLLs: {}", core.display()); + println!(" Models: {}", models.display()); + ( + core.to_string_lossy().into_owned(), + models.to_string_lossy().into_owned(), + ) + } else { + println!("Using default paths (exe directory)"); + ( + exe_dir.to_string_lossy().into_owned(), + exe_dir.join("models").to_string_lossy().into_owned(), + ) + }; + + // ── 2. 
Initialise the manager ──────────────────────────────────────── + let config = FoundryLocalConfig::new("foundry_local_samples") + .library_path(&core_path) + .model_cache_dir(&model_cache_dir) + .additional_setting("Bootstrap", "false"); + + let manager = FoundryLocalManager::create(config)?; + println!("✓ FoundryLocalManager initialized\n"); + + // ── 3. Get the nemotron model ──────────────────────────────────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? { + println!("Downloading model..."); + model + .download(Some(|progress: f64| { + print!("\r {progress:.1}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 4. Create live transcription session ───────────────────────────── + let audio_client = model.create_audio_client(); + let session = Arc::new(audio_client.create_live_transcription_session()); + + println!("Starting live transcription session..."); + session.start(None).await?; + println!("✓ Session started\n"); + + // ── 5. Start reading transcription results in background ───────────── + let mut stream = session.get_transcription_stream()?; + let read_task = tokio::spawn(async move { + let mut count = 0usize; + while let Some(result) = stream.next().await { + match result { + Ok(r) => { + let text = &r.content[0].text; + if r.is_final { + println!(); + println!(" [FINAL] {text}"); + io::stdout().flush().ok(); + } else if !text.is_empty() { + print!("{text}"); + io::stdout().flush().ok(); + } + count += 1; + } + Err(e) => { + eprintln!("\n [ERROR] Stream error: {e}"); + break; + } + } + } + count + }); + + if use_synth { + // ── 6a. 
Synthetic audio mode ───────────────────────────────────── + println!("Generating synthetic PCM audio (440Hz sine wave, 3 seconds)...\n"); + + println!("==========================================================="); + println!(" PUSHING AUDIO → SDK → Core → onnxruntime-genai"); + println!("===========================================================\n"); + + let pcm_data = generate_sine_wave_pcm(16000, 3, 440.0); + let chunk_size = 16000 / 10 * 2; // 100ms chunks + let mut chunks_pushed = 0; + for offset in (0..pcm_data.len()).step_by(chunk_size) { + let end = std::cmp::min(offset + chunk_size, pcm_data.len()); + session.append(&pcm_data[offset..end], None).await?; + chunks_pushed += 1; + } + println!("Pushed {chunks_pushed} chunks ({} bytes)", pcm_data.len()); + } else { + // ── 6b. Live microphone mode ───────────────────────────────────── + let host = cpal::default_host(); + let device = host + .default_input_device() + .expect("No input audio device available"); + println!("Microphone: {}", device.name().unwrap_or_default()); + + // Query the device's default input config and adapt + let default_config = device.default_input_config()?; + println!( + "Device default: {} Hz, {} ch, {:?}", + default_config.sample_rate().0, + default_config.channels(), + default_config.sample_format() + ); + + let device_rate = default_config.sample_rate().0; + let device_channels = default_config.channels(); + let mic_config: cpal::StreamConfig = default_config.into(); + + let session_for_mic = Arc::clone(&session); + let rt = tokio::runtime::Handle::current(); + + // Build the stream with the device's native sample format (f32) + // and convert to 16kHz/16-bit/mono PCM for the SDK + let input_stream = device.build_input_stream( + &mic_config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + // Step 1: Mix to mono if stereo+ + let mono: Vec = if device_channels > 1 { + data.chunks(device_channels as usize) + .map(|frame| frame.iter().sum::() / device_channels as f32) + 
.collect() + } else { + data.to_vec() + }; + + // Step 2: Resample to 16kHz if device rate differs + let resampled = if device_rate != 16000 { + resample(&mono, device_rate, 16000) + } else { + mono + }; + + // Step 3: Convert f32 → i16 → little-endian bytes + let bytes: Vec = resampled + .iter() + .flat_map(|&s| { + let clamped = s.clamp(-1.0, 1.0); + let sample = (clamped * i16::MAX as f32) as i16; + sample.to_le_bytes() + }) + .collect(); + + if !bytes.is_empty() { + let session_ref = Arc::clone(&session_for_mic); + rt.spawn(async move { + if let Err(e) = session_ref.append(&bytes, None).await { + eprintln!("Append error: {e}"); + } + }); + } + }, + |err| eprintln!("Microphone stream error: {err}"), + None, + )?; + + input_stream.play()?; + + println!(); + println!("==========================================================="); + println!(" LIVE TRANSCRIPTION ACTIVE"); + println!(" Speak into your microphone."); + println!(" Transcription appears in real-time."); + println!(" Press ENTER to stop recording."); + println!("==========================================================="); + println!(); + + // Block until user presses ENTER + let mut line = String::new(); + io::stdin().read_line(&mut line)?; + + drop(input_stream); + println!("Microphone stopped."); + } + + // ── 7. Stop session and wait for results ───────────────────────────── + println!("\nStopping session (flushing remaining audio)..."); + session.stop(None).await?; + println!("✓ Session stopped\n"); + + let result_count = read_task.await?; + + println!("==========================================================="); + println!(" Total transcription results: {result_count}"); + println!("==========================================================="); + + // ── 8. 
Cleanup ─────────────────────────────────────────────────────── + println!("\nUnloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} + +/// Generate synthetic PCM audio (sine wave, 16kHz, 16-bit signed little-endian, mono). +fn generate_sine_wave_pcm(sample_rate: i32, duration_seconds: i32, frequency: f64) -> Vec { + let total_samples = (sample_rate * duration_seconds) as usize; + let mut pcm_bytes = vec![0u8; total_samples * 2]; + + for i in 0..total_samples { + let t = i as f64 / sample_rate as f64; + let sample = + (i16::MAX as f64 * 0.5 * (2.0 * std::f64::consts::PI * frequency * t).sin()) as i16; + let bytes = sample.to_le_bytes(); + pcm_bytes[i * 2] = bytes[0]; + pcm_bytes[i * 2 + 1] = bytes[1]; + } + + pcm_bytes +} + +/// Simple linear-interpolation resampler (e.g. 48kHz → 16kHz). +fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec { + if from_rate == to_rate { + return input.to_vec(); + } + let ratio = from_rate as f64 / to_rate as f64; + let out_len = (input.len() as f64 / ratio).ceil() as usize; + let mut output = Vec::with_capacity(out_len); + for i in 0..out_len { + let src_idx = i as f64 * ratio; + let idx = src_idx as usize; + let frac = src_idx - idx as f64; + let s0 = input[idx.min(input.len() - 1)]; + let s1 = input[(idx + 1).min(input.len() - 1)]; + output.push(s0 + (s1 - s0) * frac as f32); + } + output +} diff --git a/sdk/rust/Cargo.toml b/sdk/rust/Cargo.toml index 2a6292b7..e089b802 100644 --- a/sdk/rust/Cargo.toml +++ b/sdk/rust/Cargo.toml @@ -21,6 +21,7 @@ serde_json = "1" thiserror = "2" tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] } tokio-stream = "0.1" +tokio-util = "0.7" futures-core = "0.3" reqwest = { version = "0.12", features = ["json"] } urlencoding = "2" diff --git a/sdk/rust/src/detail/core_interop.rs b/sdk/rust/src/detail/core_interop.rs index 43884d7f..0d17fe62 100644 --- a/sdk/rust/src/detail/core_interop.rs +++ b/sdk/rust/src/detail/core_interop.rs @@ -48,6 
+48,19 @@ impl ResponseBuffer { } } +/// Request buffer with binary payload for `execute_command_with_binary`. +/// +/// Used for audio streaming — carries both JSON params and raw PCM bytes. +#[repr(C)] +struct StreamingRequestBuffer { + command: *const i8, + command_length: i32, + data: *const i8, + data_length: i32, + binary_data: *const u8, + binary_data_length: i32, +} + /// Signature for `execute_command`. type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer); @@ -63,6 +76,10 @@ type ExecuteCommandWithCallbackFn = unsafe extern "C" fn( *mut std::ffi::c_void, ); +/// Signature for `execute_command_with_binary`. +type ExecuteCommandWithBinaryFn = + unsafe extern "C" fn(*const StreamingRequestBuffer, *mut ResponseBuffer); + // ── Library name helpers ───────────────────────────────────────────────────── #[cfg(target_os = "windows")] @@ -237,6 +254,8 @@ pub(crate) struct CoreInterop { CallbackFn, *mut std::ffi::c_void, ), + execute_command_with_binary: + Option, } impl std::fmt::Debug for CoreInterop { @@ -307,12 +326,22 @@ impl CoreInterop { *sym }; + // SAFETY: Same as above — symbol must match `ExecuteCommandWithBinaryFn`. + // Optional: older native cores may not export this symbol (used for audio streaming). + let execute_command_with_binary: Option = unsafe { + library + .get::(b"execute_command_with_binary\0") + .ok() + .map(|sym| *sym) + }; + Ok(Self { _library: library, #[cfg(target_os = "windows")] _dependency_libs, execute_command, execute_command_with_callback, + execute_command_with_binary, }) } @@ -354,6 +383,61 @@ impl CoreInterop { Self::process_response(response) } + /// Execute a command with an additional binary payload. + /// + /// Used for audio streaming — `binary_data` carries raw PCM bytes + /// alongside the JSON parameters. 
+ pub fn execute_command_with_binary( + &self, + command: &str, + params: Option<&Value>, + binary_data: &[u8], + ) -> Result { + let native_fn = self.execute_command_with_binary.ok_or_else(|| { + FoundryLocalError::CommandExecution { + reason: "execute_command_with_binary is not supported by this native core \ + (symbol not found)" + .into(), + } + })?; + + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid command string: {e}"), + })?; + + let data_json = match params { + Some(v) => serde_json::to_string(v)?, + None => String::new(), + }; + let data_cstr = + CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid data string: {e}"), + })?; + + let request = StreamingRequestBuffer { + command: cmd.as_ptr(), + command_length: cmd.as_bytes().len() as i32, + data: data_cstr.as_ptr(), + data_length: data_cstr.as_bytes().len() as i32, + binary_data: if binary_data.is_empty() { + std::ptr::null() + } else { + binary_data.as_ptr() + }, + binary_data_length: binary_data.len() as i32, + }; + + let mut response = ResponseBuffer::new(); + + // SAFETY: `request` fields point into `cmd`, `data_cstr`, and + // `binary_data` which are all alive for the duration of this call. + unsafe { + (native_fn)(&request, &mut response); + } + + Self::process_response(response) + } + /// Execute a command that streams results back via `callback`. /// /// Each chunk delivered by the native library is decoded as UTF-8 and diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs index 3a87a1c3..2329fcb1 100644 --- a/sdk/rust/src/detail/model.rs +++ b/sdk/rust/src/detail/model.rs @@ -14,6 +14,7 @@ use super::model_variant::ModelVariant; use crate::error::{FoundryLocalError, Result}; use crate::openai::AudioClient; use crate::openai::ChatClient; +use crate::openai::LiveAudioTranscriptionSession; use crate::types::ModelInfo; /// The public model type. 
@@ -242,6 +243,14 @@ impl Model { self.selected_variant().create_audio_client() } + /// Create a [`LiveAudioTranscriptionSession`] bound to the (selected) variant. + /// + /// Configure the session's [`settings`](LiveAudioTranscriptionSession::settings) + /// before calling [`start`](LiveAudioTranscriptionSession::start). + pub fn create_live_transcription_session(&self) -> LiveAudioTranscriptionSession { + self.selected_variant().create_live_transcription_session() + } + /// Available variants of this model. /// /// For a single-variant model (e.g. from diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index ca1a83c7..fcec3e4b 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -15,6 +15,7 @@ use crate::catalog::CacheInvalidator; use crate::error::Result; use crate::openai::AudioClient; use crate::openai::ChatClient; +use crate::openai::LiveAudioTranscriptionSession; use crate::types::ModelInfo; /// Represents one specific variant of a model (a particular id within an alias @@ -148,4 +149,8 @@ impl ModelVariant { pub(crate) fn create_audio_client(&self) -> AudioClient { AudioClient::new(&self.info.id, Arc::clone(&self.core)) } + + pub(crate) fn create_live_transcription_session(&self) -> LiveAudioTranscriptionSession { + LiveAudioTranscriptionSession::new(&self.info.id, Arc::clone(&self.core)) + } } diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index 872a875c..9fb4bb85 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -31,8 +31,10 @@ pub use async_openai::types::chat::{ // Re-export OpenAI response types for convenience. 
pub use crate::openai::{ - AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, - TranscriptionSegment, TranscriptionWord, + AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, ContentPart, + CoreErrorResponse, LiveAudioTranscriptionOptions, LiveAudioTranscriptionResponse, + LiveAudioTranscriptionSession, LiveAudioTranscriptionStream, TranscriptionSegment, + TranscriptionWord, }; pub use async_openai::types::chat::{ ChatChoice, ChatChoiceStream, ChatCompletionMessageToolCall, diff --git a/sdk/rust/src/openai/audio_client.rs b/sdk/rust/src/openai/audio_client.rs index 0319da38..cc1813d0 100644 --- a/sdk/rust/src/openai/audio_client.rs +++ b/sdk/rust/src/openai/audio_client.rs @@ -9,6 +9,7 @@ use crate::detail::core_interop::CoreInterop; use crate::error::{FoundryLocalError, Result}; use super::json_stream::JsonStream; +use super::live_audio_client::LiveAudioTranscriptionSession; /// A segment of a transcription, as returned by the OpenAI-compatible API. #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] @@ -196,6 +197,15 @@ impl AudioClient { Ok(AudioTranscriptionStream::new(rx)) } + /// Create a [`LiveAudioTranscriptionSession`] for real-time audio + /// streaming transcription. + /// + /// Configure the session's [`settings`](LiveAudioTranscriptionSession::settings) + /// before calling [`start`](LiveAudioTranscriptionSession::start). + pub fn create_live_transcription_session(&self) -> LiveAudioTranscriptionSession { + LiveAudioTranscriptionSession::new(&self.model_id, Arc::clone(&self.core)) + } + fn validate_path(path: &str) -> Result<()> { if path.trim().is_empty() { return Err(FoundryLocalError::Validation { diff --git a/sdk/rust/src/openai/live_audio_client.rs b/sdk/rust/src/openai/live_audio_client.rs new file mode 100644 index 00000000..0c83f01f --- /dev/null +++ b/sdk/rust/src/openai/live_audio_client.rs @@ -0,0 +1,737 @@ +//! Live audio transcription streaming session. +//! +//! 
Provides real-time audio streaming ASR (Automatic Speech Recognition). +//! Audio data from a microphone (or other source) is pushed in as PCM chunks +//! and transcription results are returned as an async [`Stream`](futures_core::Stream). +//! +//! # Example +//! +//! ```ignore +//! let audio_client = model.create_audio_client(); +//! let mut session = audio_client.create_live_transcription_session(); +//! session.settings.sample_rate = 16000; +//! session.settings.channels = 1; +//! session.settings.language = Some("en".into()); +//! +//! session.start(None).await?; +//! +//! // Push audio from microphone callback +//! session.append(&pcm_bytes, None).await?; +//! +//! // Read results as async stream +//! use tokio_stream::StreamExt; +//! let mut stream = session.get_transcription_stream()?; +//! while let Some(result) = stream.next().await { +//! let result = result?; +//! print!("{}", result.content[0].text); +//! } +//! +//! session.stop(None).await?; +//! ``` + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use serde_json::json; +use tokio_util::sync::CancellationToken; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; + +// ── Types ──────────────────────────────────────────────────────────────────── + +/// Audio format settings for a live transcription session. +/// +/// Must be configured before calling [`LiveAudioTranscriptionSession::start`]. +/// Settings are frozen once the session starts. +#[derive(Debug, Clone)] +pub struct LiveAudioTranscriptionOptions { + /// PCM sample rate in Hz. Default: 16000. + pub sample_rate: i32, + /// Number of audio channels. Default: 1 (mono). + pub channels: i32, + /// Number of bits per audio sample. Default: 16. + pub bits_per_sample: i32, + /// Optional BCP-47 language hint (e.g., `"en"`, `"zh"`). + pub language: Option, + /// Maximum number of audio chunks buffered in the internal push queue. 
+ /// If the queue is full, [`LiveAudioTranscriptionSession::append`] will + /// wait asynchronously. + /// Default: 100 (~3 seconds of audio at typical chunk sizes). + pub push_queue_capacity: usize, +} + +impl Default for LiveAudioTranscriptionOptions { + fn default() -> Self { + Self { + sample_rate: 16000, + channels: 1, + bits_per_sample: 16, + language: None, + push_queue_capacity: 100, + } + } +} + +/// Internal raw deserialization target matching the native core's JSON format. +#[derive(Debug, Clone, serde::Deserialize)] +struct LiveAudioTranscriptionRaw { + #[serde(default)] + is_final: bool, + #[serde(default)] + text: String, + start_time: Option, + end_time: Option, +} + +/// A content part within a [`LiveAudioTranscriptionResponse`]. +/// +/// Mirrors the C# `ContentPart` shape from the OpenAI Realtime API so that +/// callers can access `result.content[0].text` or `result.content[0].transcript` +/// consistently across SDKs. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ContentPart { + /// The transcribed text. + pub text: String, + /// Same as `text` — provided for OpenAI Realtime API compatibility. + pub transcript: String, +} + +/// Transcription result from a live audio streaming session. +/// +/// Shaped to match the C# `LiveAudioTranscriptionResponse : ConversationItem` +/// so that callers access text via `result.content[0].text` or +/// `result.content[0].transcript`. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct LiveAudioTranscriptionResponse { + /// Content parts — typically a single element. Access text via + /// `result.content[0].text` or `result.content[0].transcript`. + pub content: Vec, + /// Whether this is a final or partial (interim) result. + /// Nemotron models always return `true`; other models may return `false` + /// for interim hypotheses that will be replaced by a subsequent final result. 
+ pub is_final: bool, + /// Start time offset of this segment in the audio stream (seconds). + pub start_time: Option, + /// End time offset of this segment in the audio stream (seconds). + pub end_time: Option, +} + +impl LiveAudioTranscriptionResponse { + /// Parse a transcription response from the native core's JSON format. + pub fn from_json(json: &str) -> Result { + let raw: LiveAudioTranscriptionRaw = serde_json::from_str(json)?; + Ok(Self::from_raw(raw)) + } + + fn from_raw(raw: LiveAudioTranscriptionRaw) -> Self { + Self { + content: vec![ContentPart { + transcript: raw.text.clone(), + text: raw.text, + }], + is_final: raw.is_final, + start_time: raw.start_time, + end_time: raw.end_time, + } + } +} + +/// Structured error response from the native core. +#[derive(Debug, Clone, serde::Deserialize)] +pub struct CoreErrorResponse { + /// Error code (e.g. `"ASR_SESSION_NOT_FOUND"`). + pub code: String, + /// Human-readable error message. + pub message: String, + /// Whether this error is transient (retryable). + #[serde(rename = "isTransient", default)] + pub is_transient: bool, +} + +impl CoreErrorResponse { + /// Attempt to parse a native error string as structured JSON. + /// Returns `None` if the error is not valid JSON or doesn't match the schema. + pub fn try_parse(error_string: &str) -> Option { + serde_json::from_str(error_string).ok() + } +} + +// ── Stream type ────────────────────────────────────────────────────────────── + +/// An async stream of [`LiveAudioTranscriptionResponse`] items. +/// +/// Returned by [`LiveAudioTranscriptionSession::get_transcription_stream`]. +/// Implements [`futures_core::Stream`]. 
+pub struct LiveAudioTranscriptionStream { + rx: tokio::sync::mpsc::UnboundedReceiver>, +} + +impl Unpin for LiveAudioTranscriptionStream {} + +impl futures_core::Stream for LiveAudioTranscriptionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } +} + +// ── Session state ──────────────────────────────────────────────────────────── + +struct SessionState { + session_handle: Option, + started: bool, + stopped: bool, + push_tx: Option>>, + output_tx: Option>>, + output_rx: Option>>, + push_loop_handle: Option>, + active_settings: Option, +} + +impl SessionState { + fn new() -> Self { + Self { + session_handle: None, + started: false, + stopped: false, + push_tx: None, + output_tx: None, + output_rx: None, + push_loop_handle: None, + active_settings: None, + } + } +} + +// ── Session ────────────────────────────────────────────────────────────────── + +/// Session for real-time audio streaming ASR (Automatic Speech Recognition). +/// +/// Audio data from a microphone (or other source) is pushed in as PCM chunks +/// via [`append`](Self::append), and transcription results are returned as an +/// async [`Stream`](futures_core::Stream) via +/// [`get_transcription_stream`](Self::get_transcription_stream). +/// +/// Created via [`AudioClient::create_live_transcription_session`](super::AudioClient::create_live_transcription_session). +/// +/// # Thread safety +/// +/// [`append`](Self::append) can be called from any thread (including +/// high-frequency audio callbacks). Pushes are internally serialized via a +/// bounded channel to prevent unbounded memory growth and ensure ordering. +/// +/// # Cancellation +/// +/// All lifecycle methods accept an optional [`CancellationToken`]. Pass `None` +/// to use the default (no cancellation). +pub struct LiveAudioTranscriptionSession { + model_id: String, + core: Arc, + /// Audio format settings. 
Must be configured before calling [`start`](Self::start). + /// Settings are frozen once the session starts. + pub settings: LiveAudioTranscriptionOptions, + state: tokio::sync::Mutex, +} + +impl LiveAudioTranscriptionSession { + pub(crate) fn new(model_id: &str, core: Arc) -> Self { + Self { + model_id: model_id.to_owned(), + core, + settings: LiveAudioTranscriptionOptions::default(), + state: tokio::sync::Mutex::new(SessionState::new()), + } + } + + /// Start a real-time audio streaming session. + /// + /// Must be called before [`append`](Self::append) or + /// [`get_transcription_stream`](Self::get_transcription_stream). + /// Settings are frozen after this call. + /// + /// # Cancellation + /// + /// Pass a [`CancellationToken`] to abort the start operation. If + /// cancelled, the session is left in a clean (not-started) state. + pub async fn start(&self, ct: Option) -> Result<()> { + let mut state = self.state.lock().await; + + if state.started { + return Err(FoundryLocalError::Validation { + reason: "Streaming session already started. 
Call stop() first.".into(), + }); + } + + // Freeze settings + let active_settings = self.settings.clone(); + + // Create output channel (unbounded — only the push loop writes) + let (output_tx, output_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + + // Create push channel (bounded — backpressure if native core is slower than real-time) + let (push_tx, push_rx) = + tokio::sync::mpsc::channel::>(active_settings.push_queue_capacity); + + // Build request params + let mut params = serde_json::Map::new(); + params.insert("Model".into(), json!(self.model_id)); + params.insert( + "SampleRate".into(), + json!(active_settings.sample_rate.to_string()), + ); + params.insert( + "Channels".into(), + json!(active_settings.channels.to_string()), + ); + params.insert( + "BitsPerSample".into(), + json!(active_settings.bits_per_sample.to_string()), + ); + if let Some(ref lang) = active_settings.language { + params.insert("Language".into(), json!(lang)); + } + + let request = json!({ "Params": serde_json::Value::Object(params) }); + + // Start the native audio stream session (synchronous FFI on blocking thread) + let core = Arc::clone(&self.core); + let start_future = tokio::task::spawn_blocking(move || { + core.execute_command("audio_stream_start", Some(&request)) + }); + + let session_handle = if let Some(token) = &ct { + tokio::select! { + result = start_future => { + result.map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Start audio stream task join error: {e}"), + })?? + } + _ = token.cancelled() => { + return Err(FoundryLocalError::CommandExecution { + reason: "Start cancelled".into(), + }); + } + } + } else { + start_future + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Start audio stream task join error: {e}"), + })?? 
+ }; + + if session_handle.is_empty() { + return Err(FoundryLocalError::CommandExecution { + reason: "Native core did not return a session handle.".into(), + }); + } + + state.session_handle = Some(session_handle.clone()); + state.started = true; + state.stopped = false; + state.active_settings = Some(active_settings); + + // Spawn the push loop on a blocking thread + let push_loop_core = Arc::clone(&self.core); + let push_loop_output_tx = output_tx.clone(); + let push_loop_handle = tokio::task::spawn_blocking(move || { + Self::push_loop(push_loop_core, session_handle, push_rx, push_loop_output_tx); + }); + + state.push_tx = Some(push_tx); + state.output_tx = Some(output_tx); + state.output_rx = Some(output_rx); + state.push_loop_handle = Some(push_loop_handle); + + Ok(()) + } + + /// Push a chunk of raw PCM audio data to the streaming session. + /// + /// Can be called from any async context (including high-frequency audio + /// callbacks when wrapped). Chunks are internally queued and serialized to + /// the native core. + /// + /// The data is copied internally so the caller can reuse the buffer. + /// + /// # Cancellation + /// + /// Pass a [`CancellationToken`] to abort if the push queue is full + /// (backpressure). The audio chunk will not be queued if cancelled. + pub async fn append(&self, pcm_data: &[u8], ct: Option) -> Result<()> { + let state = self.state.lock().await; + + if !state.started || state.stopped { + return Err(FoundryLocalError::Validation { + reason: "No active streaming session. Call start() first.".into(), + }); + } + + let tx = state + .push_tx + .as_ref() + .ok_or_else(|| FoundryLocalError::Internal { + reason: "Push channel missing".into(), + })?; + + // Copy the data to avoid issues if the caller reuses the buffer + let data = pcm_data.to_vec(); + + if let Some(token) = &ct { + tokio::select! 
{ + result = tx.send(data) => { + result.map_err(|_| FoundryLocalError::CommandExecution { + reason: "Push channel closed — session may have been stopped".into(), + }) + } + _ = token.cancelled() => { + Err(FoundryLocalError::CommandExecution { + reason: "Append cancelled".into(), + }) + } + } + } else { + tx.send(data) + .await + .map_err(|_| FoundryLocalError::CommandExecution { + reason: "Push channel closed — session may have been stopped".into(), + }) + } + } + + /// Get the async stream of transcription results. + /// + /// Results arrive as the native ASR engine processes audio data. + /// Can only be called once per session (the receiver is moved out). + pub fn get_transcription_stream(&self) -> Result { + // We need to try_lock to avoid blocking — but in practice this is + // called from the same task that called start(). + let mut state = self + .state + .try_lock() + .map_err(|_| FoundryLocalError::Internal { + reason: "Could not acquire session lock for get_transcription_stream".into(), + })?; + + let rx = state + .output_rx + .take() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "No active streaming session, or stream already taken. \ + Call start() first and only call get_transcription_stream() once." + .into(), + })?; + + Ok(LiveAudioTranscriptionStream { rx }) + } + + /// Signal end-of-audio and stop the streaming session. + /// + /// Any remaining buffered audio in the push queue will be drained to the + /// native core first. Final results are delivered through the transcription + /// stream before it completes. + /// + /// # Cancellation safety + /// + /// Even if the provided [`CancellationToken`] fires, the native session + /// stop is still performed to avoid native session leaks (matching the C# + /// `StopAsync` cancellation-safe pattern). 
+ pub async fn stop(&self, ct: Option) -> Result<()> {
+ let mut state = self.state.lock().await;
+
+ if !state.started || state.stopped {
+ return Ok(()); // already stopped or never started
+ }
+
+ state.stopped = true;
+
+ // 1. Complete the push channel so the push loop drains remaining items
+ state.push_tx.take();
+
+ // 2. Wait for the push loop to finish draining
+ if let Some(handle) = state.push_loop_handle.take() {
+ let _ = handle.await;
+ }
+
+ // 3. Tell native core to flush and finalize
+ let session_handle = state
+ .session_handle
+ .as_ref()
+ .ok_or_else(|| FoundryLocalError::Internal {
+ reason: "Session handle missing during stop".into(),
+ })?
+ .clone();
+
+ let params = json!({
+ "Params": {
+ "SessionHandle": session_handle
+ }
+ });
+
+ let core = Arc::clone(&self.core);
+ let stop_future = tokio::task::spawn_blocking(move || {
+ core.execute_command("audio_stream_stop", Some(¶ms))
+ });
+ // NOTE(review): `¶ms` above is mojibake for `&params` (HTML-entity damage in
+ // this capture); the same corruption appears on the retry call and in
+ // push_loop/Drop — restore `&params` from the SDK source before building.
+
+ // Even if ct fires, we MUST complete the native stop to avoid session leaks.
+ // This mirrors the C# StopAsync cancellation-safe pattern.
+ let stop_result = if let Some(token) = &ct {
+ tokio::select! {
+ result = stop_future => {
+ result.map_err(|e| FoundryLocalError::CommandExecution {
+ reason: format!("Stop audio stream task join error: {e}"),
+ })?
+ }
+ _ = token.cancelled() => {
+ // ct fired — retry without cancellation to prevent native session leak
+ // NOTE(review): the original stop_future is detached, not aborted, so two
+ // audio_stream_stop calls can race for the same handle here — confirm the
+ // native core tolerates a duplicate stop.
+ let core_retry = Arc::clone(&self.core);
+ let params_retry = json!({
+ "Params": { "SessionHandle": &session_handle }
+ });
+ let retry_result = tokio::task::spawn_blocking(move || {
+ core_retry.execute_command("audio_stream_stop", Some(¶ms_retry))
+ })
+ .await
+ .map_err(|e| FoundryLocalError::CommandExecution {
+ reason: format!("Stop audio stream retry task join error: {e}"),
+ })?;
+
+ // Write final result before propagating cancellation
+ Self::write_final_result(&retry_result, &state);
+ state.output_tx.take();
+ state.session_handle = None;
+ state.started = false;
+
+ return Err(FoundryLocalError::CommandExecution {
+ reason: "Stop cancelled (native session stopped via best-effort cleanup)"
+ .into(),
+ });
+ }
+ }
+ } else {
+ stop_future
+ .await
+ .map_err(|e| FoundryLocalError::CommandExecution {
+ reason: format!("Stop audio stream task join error: {e}"),
+ })?
+ };
+
+ // Parse final transcription from stop response before completing the channel
+ Self::write_final_result(&stop_result, &state);
+
+ // Complete the output channel
+ state.output_tx.take();
+ state.session_handle = None;
+ state.started = false;
+
+ // Propagate error if native stop failed
+ stop_result?;
+
+ Ok(())
+ }
+
+ /// Write a final transcription result from a stop response into the output channel.
+ fn write_final_result(stop_result: &Result, state: &SessionState) {
+ if let Ok(data) = stop_result {
+ if !data.is_empty() {
+ if let Ok(raw) = serde_json::from_str::(data) {
+ if !raw.text.is_empty() {
+ if let Some(tx) = &state.output_tx {
+ let _ = tx.send(Ok(LiveAudioTranscriptionResponse::from_raw(raw)));
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /// Internal push loop — runs entirely on a blocking thread.
+ ///
+ /// Drains the push queue and sends chunks to the native core one at a time.
+ /// Terminates the session on any native error.
+ fn push_loop(
+ core: Arc,
+ session_handle: String,
+ mut push_rx: tokio::sync::mpsc::Receiver>,
+ output_tx: tokio::sync::mpsc::UnboundedSender>,
+ ) {
+ while let Some(audio_data) = push_rx.blocking_recv() {
+ let params = json!({
+ "Params": {
+ "SessionHandle": &session_handle
+ }
+ });
+ // NOTE(review): params depends only on session_handle and is loop-invariant;
+ // it is rebuilt (and re-serialized) per chunk here — could be hoisted above
+ // the while loop.
+
+ let result =
+ core.execute_command_with_binary("audio_stream_push", Some(¶ms), &audio_data);
+
+ match result {
+ Ok(data) if !data.is_empty() => {
+ match serde_json::from_str::(&data) {
+ Ok(raw) if !raw.text.is_empty() => {
+ let response = LiveAudioTranscriptionResponse::from_raw(raw);
+ let _ = output_tx.send(Ok(response));
+ }
+ Ok(_) => {} // empty text — skip
+ Err(_) => {} // non-fatal parse error — skip
+ }
+ }
+ Ok(_) => {} // empty response — skip
+ Err(e) => {
+ // Fatal error from native core — terminate push loop
+ let error_info = CoreErrorResponse::try_parse(&format!("{e}"));
+ let code = error_info
+ .as_ref()
+ .map(|ei| ei.code.as_str())
+ .unwrap_or("UNKNOWN");
+ let _ = output_tx.send(Err(FoundryLocalError::CommandExecution {
+ reason: format!("Push failed (code={code}): {e}"),
+ }));
+ return;
+ }
+ }
+ }
+ // push_rx closed = push channel completed = push loop exits naturally
+ }
+}
+
+// ── Drop impl ────────────────────────────────────────────────────────────────
+
+impl Drop for LiveAudioTranscriptionSession {
+ fn drop(&mut self) {
+ if let Ok(mut state) = self.state.try_lock() {
+ // NOTE(review): if the lock is contended at drop time, try_lock fails and
+ // this entire cleanup is silently skipped — acceptable as best-effort, but
+ // worth confirming the native core reclaims sessions in that case.
+ // Close push channel to unblock the push loop
+ state.push_tx.take();
+ state.output_tx.take();
+
+ // Best-effort native cleanup: call audio_stream_stop synchronously
+ // to prevent native session leaks. This is critical for long-running
+ // processes where users may forget to call stop().
+ if state.started && !state.stopped {
+ if let Some(ref handle) = state.session_handle {
+ let params = serde_json::json!({
+ "Params": { "SessionHandle": handle }
+ });
+ // Synchronous FFI call — safe from Drop since execute_command
+ // is a blocking call that doesn't require an async runtime.
+ let _ = self
+ .core
+ .execute_command("audio_stream_stop", Some(¶ms));
+ }
+ state.session_handle = None;
+ state.started = false;
+ state.stopped = true;
+ }
+ }
+ }
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ // --- LiveAudioTranscriptionResponse::from_json tests ---
+
+ #[test]
+ fn from_json_parses_text_and_is_final() {
+ let json = r#"{"is_final":true,"text":"hello world","start_time":null,"end_time":null}"#;
+ let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
+
+ assert_eq!(result.content.len(), 1);
+ assert_eq!(result.content[0].text, "hello world");
+ assert_eq!(result.content[0].transcript, "hello world");
+ assert!(result.is_final);
+ }
+
+ #[test]
+ fn from_json_maps_timing_fields() {
+ let json = r#"{"is_final":false,"text":"partial","start_time":1.5,"end_time":3.0}"#;
+ let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
+
+ assert_eq!(result.content[0].text, "partial");
+ assert!(!result.is_final);
+ assert_eq!(result.start_time, Some(1.5));
+ assert_eq!(result.end_time, Some(3.0));
+ }
+
+ #[test]
+ fn from_json_empty_text_parses_successfully() {
+ let json = r#"{"is_final":true,"text":"","start_time":null,"end_time":null}"#;
+ let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
+
+ assert_eq!(result.content[0].text, "");
+ assert!(result.is_final);
+ }
+
+ #[test]
+ fn from_json_only_start_time_sets_start_time() {
+ let json = r#"{"is_final":true,"text":"word","start_time":2.0,"end_time":null}"#;
+ let result = LiveAudioTranscriptionResponse::from_json(json).unwrap();
+
+ assert_eq!(result.start_time,
Some(2.0)); + assert_eq!(result.end_time, None); + assert_eq!(result.content[0].text, "word"); + } + + #[test] + fn from_json_invalid_json_returns_error() { + let result = LiveAudioTranscriptionResponse::from_json("not valid json"); + assert!(result.is_err()); + } + + #[test] + fn from_json_content_has_text_and_transcript() { + let json = r#"{"is_final":true,"text":"test","start_time":null,"end_time":null}"#; + let result = LiveAudioTranscriptionResponse::from_json(json).unwrap(); + + // Both Text and Transcript should have the same value + assert_eq!(result.content[0].text, "test"); + assert_eq!(result.content[0].transcript, "test"); + } + + // --- LiveAudioTranscriptionOptions tests --- + + #[test] + fn options_default_values() { + let options = LiveAudioTranscriptionOptions::default(); + + assert_eq!(options.sample_rate, 16000); + assert_eq!(options.channels, 1); + assert_eq!(options.bits_per_sample, 16); + assert_eq!(options.language, None); + assert_eq!(options.push_queue_capacity, 100); + } + + // --- CoreErrorResponse tests --- + + #[test] + fn core_error_response_try_parse_valid_json() { + let json = + r#"{"code":"ASR_SESSION_NOT_FOUND","message":"Session not found","isTransient":false}"#; + let error = CoreErrorResponse::try_parse(json).unwrap(); + + assert_eq!(error.code, "ASR_SESSION_NOT_FOUND"); + assert_eq!(error.message, "Session not found"); + assert!(!error.is_transient); + } + + #[test] + fn core_error_response_try_parse_invalid_json_returns_none() { + let result = CoreErrorResponse::try_parse("not json"); + assert!(result.is_none()); + } + + #[test] + fn core_error_response_try_parse_transient_error() { + let json = r#"{"code":"BUSY","message":"Model busy","isTransient":true}"#; + let error = CoreErrorResponse::try_parse(json).unwrap(); + + assert!(error.is_transient); + } +} diff --git a/sdk/rust/src/openai/mod.rs b/sdk/rust/src/openai/mod.rs index c3d4a645..80785a0c 100644 --- a/sdk/rust/src/openai/mod.rs +++ b/sdk/rust/src/openai/mod.rs @@ -1,6 
+1,7 @@
 mod audio_client;
 mod chat_client;
 mod json_stream;
+mod live_audio_client;
 
 pub use self::audio_client::{
 AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream,
@@ -8,3 +9,7 @@ pub use self::audio_client::{
 };
 pub use self::chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream};
 pub use self::json_stream::JsonStream;
+pub use self::live_audio_client::{
+ ContentPart, CoreErrorResponse, LiveAudioTranscriptionOptions, LiveAudioTranscriptionResponse,
+ LiveAudioTranscriptionSession, LiveAudioTranscriptionStream,
+};
diff --git a/sdk/rust/tests/integration/live_audio_test.rs b/sdk/rust/tests/integration/live_audio_test.rs
new file mode 100644
index 00000000..e375cc0c
--- /dev/null
+++ b/sdk/rust/tests/integration/live_audio_test.rs
@@ -0,0 +1,106 @@
+use super::common;
+use std::sync::Arc;
+use tokio_stream::StreamExt;
+
+/// Generate synthetic PCM audio (440Hz sine wave, 16kHz, 16-bit mono).
+// NOTE(review): the return type shows as bare `Vec` here — the generic argument
+// (presumably u8, given the byte packing below) was stripped in this capture.
+fn generate_sine_wave_pcm(sample_rate: i32, duration_seconds: i32, frequency: f64) -> Vec {
+ let total_samples = (sample_rate * duration_seconds) as usize;
+ let mut pcm_bytes = vec![0u8; total_samples * 2]; // 16-bit = 2 bytes per sample
+
+ // Little-endian byte order: low byte first — the layout the push path expects
+ // for 16-bit PCM.
+ for i in 0..total_samples {
+ let t = i as f64 / sample_rate as f64;
+ let sample =
+ (i16::MAX as f64 * 0.5 * (2.0 * std::f64::consts::PI * frequency * t).sin()) as i16;
+ pcm_bytes[i * 2] = (sample & 0xFF) as u8;
+ pcm_bytes[i * 2 + 1] = ((sample >> 8) & 0xFF) as u8;
+ }
+
+ pcm_bytes
+}
+
+// --- E2E streaming test with synthetic PCM audio ---
+
+#[tokio::test]
+async fn live_streaming_e2e_with_synthetic_pcm_returns_valid_response() {
+ let manager = common::get_test_manager();
+ let catalog = manager.catalog();
+
+ // Try to get a nemotron or whisper model for audio streaming
+ let model = match catalog.get_model("nemotron").await {
+ Ok(m) => m,
+ Err(_) => match catalog.get_model(common::WHISPER_MODEL_ALIAS).await {
+ Ok(m) => m,
+ Err(_) => {
+ eprintln!("Skipping E2E test: no audio model available");
+ return;
+ }
+ },
+ };
+
+ if !model.is_cached().await.unwrap_or(false) {
+ eprintln!("Skipping E2E test: model not cached");
+ return;
+ }
+
+ model.load().await.expect("model.load() failed");
+
+ let audio_client = model.create_audio_client();
+ let session = audio_client.create_live_transcription_session();
+
+ // Verify default settings
+ assert_eq!(session.settings.sample_rate, 16000);
+ assert_eq!(session.settings.channels, 1);
+ assert_eq!(session.settings.bits_per_sample, 16);
+
+ if let Err(e) = session.start(None).await {
+ eprintln!("Skipping E2E test: could not start session: {e}");
+ model.unload().await.ok();
+ return;
+ }
+
+ // Start collecting results in background (must start before pushing audio)
+ let mut stream = session
+ .get_transcription_stream()
+ .expect("get_transcription_stream failed");
+
+ let results = Arc::new(tokio::sync::Mutex::new(Vec::new()));
+ let results_clone = Arc::clone(&results);
+ let read_task = tokio::spawn(async move {
+ while let Some(result) = stream.next().await {
+ match result {
+ Ok(r) => results_clone.lock().await.push(r),
+ Err(e) => {
+ eprintln!("Stream error: {e}");
+ break;
+ }
+ }
+ }
+ });
+
+ // Generate ~2 seconds of synthetic PCM audio (440Hz sine wave)
+ let pcm_bytes = generate_sine_wave_pcm(16000, 2, 440.0);
+
+ // Push audio in chunks (100ms each, matching typical mic callback size)
+ let chunk_size = 16000 / 10 * 2; // 100ms of 16-bit audio = 3200 bytes
+ for offset in (0..pcm_bytes.len()).step_by(chunk_size) {
+ let end = std::cmp::min(offset + chunk_size, pcm_bytes.len());
+ session
+ .append(&pcm_bytes[offset..end], None)
+ .await
+ .expect("append failed");
+ }
+
+ // Stop session to flush remaining audio and complete the stream
+ session.stop(None).await.expect("stop failed");
+ read_task.await.expect("read task failed");
+
+ // Verify response attributes — synthetic audio may or may not produce text,
+ // but the response objects should be properly structured (C#-compatible envelope)
+ let results = results.lock().await;
+ for result in results.iter() {
+ assert!(!result.content.is_empty(), "content must not be empty");
+ assert_eq!(result.content[0].text, result.content[0].transcript);
+ }
+
+ model.unload().await.expect("model.unload() failed");
+}
diff --git a/sdk/rust/tests/integration/main.rs b/sdk/rust/tests/integration/main.rs
index 04de9a23..c320b7d6 100644
--- a/sdk/rust/tests/integration/main.rs
+++ b/sdk/rust/tests/integration/main.rs
@@ -11,6 +11,7 @@ mod common;
 mod audio_client_test;
 mod catalog_test;
 mod chat_client_test;
+mod live_audio_test;
 mod manager_test;
 mod model_test;
 mod web_service_test;