diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0bd4ef89..476d2535 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,7 +38,7 @@ jobs: - name: build run: eng/ci.sh - name: Run tests - run: sudo target/x86_64-unknown-linux-musl/release/avml --compress output.lime + run: sudo target/x86_64-unknown-linux-musl/release/avml acquire --compress output.lime - name: upload artifacts uses: actions/upload-artifact@v7.0.1 with: @@ -46,8 +46,6 @@ jobs: path: | target/*/release/avml target/*/release/avml-minimal - target/*/release/avml-convert - target/*/release/avml-upload arm64: permissions: contents: read @@ -58,7 +56,7 @@ jobs: - name: build run: eng/ci.sh - name: Run tests - run: sudo target/aarch64-unknown-linux-musl/release/avml --compress output.lime + run: sudo target/aarch64-unknown-linux-musl/release/avml acquire --compress output.lime - name: upload artifacts uses: actions/upload-artifact@v7.0.1 with: @@ -66,8 +64,6 @@ jobs: path: | target/*/release/avml target/*/release/avml-minimal - target/*/release/avml-convert - target/*/release/avml-upload windows: permissions: contents: read @@ -79,10 +75,8 @@ jobs: git config --global core.eol lf - uses: actions/checkout@v6.0.2 - uses: Swatinem/rust-cache@65012b490220f477f20ab979e35ae732e6de4e68 # v2 - - name: build avml-convert - run: cargo build --release --bin avml-convert --locked - - name: build avml-upload - run: cargo build --release --bin avml-upload --locked + - name: build avml + run: cargo build --release --bin avml --locked - name: Run tests run: cargo test - name: upload artifacts @@ -90,7 +84,5 @@ jobs: with: name: windows-artifacts path: | - target/release/avml-convert.exe - target/release/avml_convert.pdb - target/release/avml-upload.exe - target/release/avml_upload.pdb + target/release/avml.exe + target/release/avml.pdb diff --git a/Cargo.toml b/Cargo.toml index c1f40172..1885cf27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,11 +13,14 @@ edition = "2024" rust-version = "1.88.0" [features] -default = ["put", "blobstore", "native-tls"] +default = ["stream", "upload", "convert", "native-tls"] put = ["dep:reqwest", "reqwest?/stream", "dep:url", "dep:tokio", "dep:tokio-util", "dep:futures"] -blobstore = ["dep:url", "dep:azure_core", "dep:azure_storage_blob", "dep:tokio", "dep:async-trait", "dep:futures"] +blobstore = ["dep:url", "dep:azure_core", "dep:azure_storage_blob", "dep:tokio", "dep:async-trait", "dep:futures", "dep:tokio-util", "tokio/sync", "tokio-util/io-util"] status = ["dep:indicatif"] native-tls = ["dep:native-tls", "azure_core?/reqwest", "reqwest?/native-tls-vendored"] +convert = [] +upload = ["put", "blobstore"] +stream = ["blobstore", "tokio/net"] [dependencies] async-trait = {version="0.1", optional=true} @@ -34,7 +37,7 @@ azure_storage_blob = {version="1.0", optional=true, default-features=false, feat indicatif = {version="0.18", optional=true, default-features=false} native-tls = {version="0.2", features=["vendored"], optional=true, default-features=false} reqwest = {version="0.13", optional=true, default-features=false} -tokio = {version="1.52", default-features=false, optional=true, features=["fs", "rt-multi-thread", "io-util", "macros"]} +tokio = {version="1.52", default-features=false, optional=true, features=["fs", "rt-multi-thread", "io-util", "macros", "net"]} tokio-util = {version="0.7", features=["codec"], optional=true, default-features=false} url = {version="2.5", optional=true, default-features=false} @@ -69,5 +72,5 @@ panic="abort" codegen-units=1 [[bin]] -name = "avml-upload" -required-features = ["put", "blobstore"] +name = "avml" +path = "src/bin/avml/main.rs" diff --git a/README.md b/README.md index 251ae568..897aec71 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,20 @@ If the memory source is not specified on the commandline, AVML will iterate over * Oracle Linux: 6.8, 6.9, 6.10, 7.3, 7.4, 7.5, 7.6, 7.9, 8.5, 9.0 * [CBL-Mariner](https://github.com/microsoft/CBL-Mariner): 1.0, 2.0 +## Subcommands + +`avml` is a single binary with subcommands. Each subcommand is gated by +a Cargo feature so a minimal build only includes the capability you need: + +| Subcommand | Feature | Default | What it does | +|------------|-----------|---------|----------------------------------------------------------------| +| `acquire` | (always) | yes | Snapshot memory to a local file (optional upload after). | +| `convert` | `convert` | yes | Convert between AVML / LiME / raw formats. | +| `upload` | `upload` | yes | Upload a local file via HTTP PUT or to Azure Block Blob. | +| `stream` | `stream` | yes | Stream a snapshot directly to a destination, no local file. | + +Build a minimal acquire-only binary with `cargo build --release --no-default-features`. + # Getting Started ## Capturing a compressed memory image @@ -39,7 +53,7 @@ If the memory source is not specified on the commandline, AVML will iterate over On the target host: ``` -avml --compress output.lime.compressed +avml acquire --compress output.lime.compressed ``` ## Capturing an uncompressed memory image @@ -47,7 +61,7 @@ avml --compress output.lime.compressed On the target host: ``` -avml output.lime +avml acquire output.lime ``` ## Capturing a memory image & uploading to Azure Blob Store @@ -60,7 +74,58 @@ SAS_URL=$(az storage blob generate-sas --account-name ACCOUNT --container CONTAI On the target host, execute avml with the generated SAS token. ``` -avml --sas-url ${SAS_URL} --delete output.lime +avml acquire --sas-url ${SAS_URL} --delete output.lime +``` + +## Streaming a memory image without writing to local disk + +For hosts where writing the snapshot to a local file first is undesirable +(read-only root, limited disk, forensic chain-of-custody concerns), use +the `stream` subcommand. It picks the memory source once up front (same +preference order as `acquire`'s `/dev/stdout` path — `/proc/kcore`, then +`/dev/crash`, then `/dev/mem`; pass `--source` to override) and writes +bytes sequentially to the chosen destination. The source cannot be +changed mid-stream, so there is no automatic source fallback. + +### To Azure Block Blob Storage + +``` +avml stream blob ${SAS_URL} +``` + +- The block size is derived automatically so the snapshot fits within + Azure's per-blob 50,000-block limit. `--sas-block-size` (MiB) acts as + a *floor*; if the derived minimum is larger, the larger value wins. +- `--sas-block-concurrency` caps the number of in-flight `stage_block` + calls. Peak RAM is approximately `(concurrency + 1) * block_size`. +- If the snapshot fails mid-upload, staged blocks are abandoned without + being committed; Azure discards them automatically per its standard + policy. + +### To a remote TCP listener + +On the collector host: + +``` +nc -l 9000 > snapshot.lime +``` + +On the target host: + +``` +avml stream tcp collector.example.com:9000 +``` + +avml connects once and writes the snapshot sequentially. If the +connection drops mid-stream, the snapshot aborts; there is no resume. +No TLS — pair with an SSH tunnel or stunnel for confidentiality and +integrity if needed. + +## Uploading a previously-captured snapshot + +``` +avml upload put ./output.lime ${URL} # HTTP PUT +avml upload blob ./output.lime ${SAS_URL} # Azure Block Blob ``` ## Capturing a memory image of an Azure VM using VM Extensions @@ -71,7 +136,7 @@ On a secure host with `az cli` credentials, do the following: 2. Create `config.json` containing the following information: ``` { - "commandToExecute": "./avml --compress --sas-url --delete", + "commandToExecute": "./avml acquire --compress --sas-url --delete", "fileUris": ["https://FULL.URL.TO.AVML.example.com/avml"] } ``` @@ -85,73 +150,36 @@ On a secure host, generate a [S3 pre-signed URL](https://docs.aws.amazon.com/cli On the target host, execute avml with the generated pre-signed URL. ``` -avml --put ${URL} --delete output.lime +avml acquire --url ${URL} --delete output.lime ``` ## To decompress an AVML-compressed image ``` -avml-convert ./compressed.lime ./uncompressed.lime +avml convert ./compressed.lime ./uncompressed.lime ``` ## To compress an uncompressed LiME image ``` -avml-convert --source-format lime --format lime_compressed ./uncompressed.lime ./compressed.lime +avml convert --source-format lime --format lime_compressed ./uncompressed.lime ./compressed.lime ``` # Usage ``` -A portable volatile memory acquisition tool - -Usage: avml [OPTIONS] - -Arguments: - - name of the file to write to on local system - -Options: - --compress - compress via snappy - - --source - specify input source +A portable volatile memory acquisition tool for Linux - Possible values: - - /dev/crash: - Provides a read-only view of physical memory. Access to memory using this device must be paged aligned and read one page at a time - - /dev/mem: - Provides a read-write view of physical memory, though AVML opens it in a read-only fashion. Access to to memory using this device can be disabled using the kernel configuration options `CONFIG_STRICT_DEVMEM` or `CONFIG_IO_STRICT_DEVMEM` - - /proc/kcore: - Provides a virtual ELF coredump of kernel memory. This can be used to access physical memory +Usage: avml - --max-disk-usage - Specify the maximum estimated disk usage (in MB) - - --max-disk-usage-percentage - Specify the maximum estimated disk usage to stay under - - --url - upload via HTTP PUT upon acquisition - - --delete - delete upon successful upload - - --sas-url - upload via Azure Blob Store upon acquisition - - --sas-block-size - specify maximum block size in MiB; must be greater than 0 - - --sas-block-concurrency - specify blob upload concurrency; must be greater than 0 - - -h, --help - Print help (see a summary with '-h') - - -V, --version - Print version +Commands: + acquire Acquire a memory snapshot to a local file (and optionally upload it) + convert Convert between AVML and LiME snapshot formats and a raw memory image + upload Upload an already-acquired snapshot file to remote storage + stream Stream a memory snapshot directly to remote storage, without writing it to a local file + help Print this message or the help of the given subcommand(s) ``` +Run `avml --help` for per-command options. + # Building on Ubuntu # Install MUSL diff --git a/eng/build.sh b/eng/build.sh index d27ce5ce..d79c7be4 100755 --- a/eng/build.sh +++ b/eng/build.sh @@ -18,8 +18,5 @@ done cargo +stable build --release --no-default-features --target ${ARCH}-unknown-linux-musl --locked cp target/${ARCH}-unknown-linux-musl/release/avml target/${ARCH}-unknown-linux-musl/release/avml-minimal cargo +stable build --release --target ${ARCH}-unknown-linux-musl --locked -cargo +stable build --release --target ${ARCH}-unknown-linux-musl --locked --bin avml-upload --features "put blobstore status" strip target/${ARCH}-unknown-linux-musl/release/avml strip target/${ARCH}-unknown-linux-musl/release/avml-minimal -strip target/${ARCH}-unknown-linux-musl/release/avml-convert -strip target/${ARCH}-unknown-linux-musl/release/avml-upload diff --git a/eng/test-azure-image.sh b/eng/test-azure-image.sh index ba6b3976..ed4dcbd8 100755 --- a/eng/test-azure-image.sh +++ b/eng/test-azure-image.sh @@ -45,6 +45,6 @@ IP=$(az vm create -g ${GROUP} --size ${SIZE} -n ${VM} --image ${SKU} --public-ip ssh-keygen -R ${IP} 2>/dev/null > /dev/null quiet scp -oStrictHostKeyChecking=no ${EXE} ${IP}:./avml quiet ssh -oStrictHostKeyChecking=no ${IP} sudo chmod +x avml -quiet ssh -oStrictHostKeyChecking=no ${IP} sudo ./avml --compress /mnt/image.lime +quiet ssh -oStrictHostKeyChecking=no ${IP} sudo ./avml acquire --compress /mnt/image.lime quiet ssh -oStrictHostKeyChecking=no ${IP} sudo chmod a+r /mnt/image.lime quiet scp -oStrictHostKeyChecking=no ${IP}:/mnt/image.lime ./${SKU}.lime diff --git a/src/bin/avml.rs b/src/bin/avml/acquire.rs similarity index 53% rename from src/bin/avml.rs rename to src/bin/avml/acquire.rs index 27430984..4d2842f2 100644 --- a/src/bin/avml.rs +++ b/src/bin/avml/acquire.rs @@ -3,17 +3,15 @@ use avml::{Format, Result, Snapshot, Source, iomem}; use clap::Parser; -#[cfg(feature = "blobstore")] +#[cfg(feature = "upload")] use core::num::NonZeroUsize; use core::{num::NonZeroU64, ops::Range}; use std::path::PathBuf; -#[cfg(any(feature = "blobstore", feature = "put"))] +#[cfg(feature = "upload")] use {avml::Error, tokio::fs::remove_file, url::Url}; #[derive(Parser)] -/// A portable volatile memory acquisition tool for Linux -#[command(author, version, about, long_about = None)] -struct Config { +pub struct Args { /// compress via snappy #[arg(long)] compress: bool, @@ -31,27 +29,27 @@ struct Config { max_disk_usage_percentage: Option, /// upload via HTTP PUT upon acquisition - #[cfg(feature = "put")] + #[cfg(feature = "upload")] #[arg(long)] url: Option, /// delete upon successful upload - #[cfg(any(feature = "blobstore", feature = "put"))] + #[cfg(feature = "upload")] #[arg(long)] delete: bool, /// upload via Azure Blob Store upon acquisition - #[cfg(feature = "blobstore")] + #[cfg(feature = "upload")] #[arg(long)] sas_url: Option, /// specify maximum block size in MiB; must be greater than 0 - #[cfg(feature = "blobstore")] + #[cfg(feature = "upload")] #[arg(long)] sas_block_size: Option, /// specify blob upload concurrency; must be greater than 0 - #[cfg(feature = "blobstore")] + #[cfg(feature = "upload")] #[arg(long)] sas_block_concurrency: Option, @@ -75,31 +73,38 @@ fn disk_usage_percentage(s: &str) -> core::result::Result { } } -#[cfg(any(feature = "blobstore", feature = "put"))] -async fn upload(config: &Config) -> Result<()> { - let mut delete = false; +pub fn run(args: &Args) -> Result<()> { + let format = Format::from(args.compress); - #[cfg(feature = "put")] - { - if let Some(ref url) = config.url { - avml::put(&config.filename, url).await?; - delete = true; - } + let ranges = iomem::parse()?; + let snapshot = Snapshot::new(&args.filename, ranges) + .source(args.source.clone()) + .max_disk_usage_percentage(args.max_disk_usage_percentage) + .max_disk_usage(args.max_disk_usage) + .format(format); + snapshot.create()?; + Ok(()) +} + +#[cfg(feature = "upload")] +pub async fn upload_after_acquire(args: &Args) -> Result<()> { + let mut did_upload = false; + + if let Some(ref url) = args.url { + avml::put(&args.filename, url).await?; + did_upload = true; } - #[cfg(feature = "blobstore")] - { - if let Some(ref sas_url) = config.sas_url { - let uploader = avml::BlobUploader::new(sas_url)? - .block_size(config.sas_block_size) - .concurrency(config.sas_block_concurrency); - uploader.upload_file(&config.filename).await?; - delete = true; - } + if let Some(ref sas_url) = args.sas_url { + let uploader = avml::BlobUploader::new(sas_url)? + .block_size(args.sas_block_size) + .concurrency(args.sas_block_concurrency); + uploader.upload_file(&args.filename).await?; + did_upload = true; } - if delete && config.delete { - remove_file(&config.filename) + if did_upload && args.delete { + remove_file(&args.filename) .await .map_err(|source| Error::Io { context: "unable to remove snapshot", @@ -109,30 +114,3 @@ async fn upload(config: &Config) -> Result<()> { Ok(()) } - -fn acquire(config: &Config) -> Result<()> { - let format = Format::from(config.compress); - - let ranges = iomem::parse()?; - let snapshot = Snapshot::new(&config.filename, ranges) - .source(config.source.clone()) - .max_disk_usage_percentage(config.max_disk_usage_percentage) - .max_disk_usage(config.max_disk_usage) - .format(format); - snapshot.create()?; - Ok(()) -} - -#[cfg(not(any(feature = "blobstore", feature = "put")))] -fn main() -> Result<()> { - let config = Config::parse(); - acquire(&config) -} - -#[cfg(any(feature = "blobstore", feature = "put"))] -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<()> { - let config = Config::parse(); - acquire(&config)?; - upload(&config).await -} diff --git a/src/bin/avml-convert.rs b/src/bin/avml/convert.rs similarity index 92% rename from src/bin/avml-convert.rs rename to src/bin/avml/convert.rs index a2f32701..b60bfb8e 100644 --- a/src/bin/avml-convert.rs +++ b/src/bin/avml/convert.rs @@ -10,6 +10,50 @@ use std::{ path::{Path, PathBuf}, }; +#[derive(Parser)] +pub struct Args { + /// specify input format + #[arg(long, value_enum, default_value_t = CliFormat::LimeCompressed)] + source_format: CliFormat, + + /// specify output format + #[arg(long, value_enum, default_value_t = CliFormat::Lime)] + format: CliFormat, + + /// name of the source file to read from on local system + src: PathBuf, + + /// name of the destination file to write to on local system + dst: PathBuf, +} + +#[derive(ValueEnum, Clone, Copy, PartialEq, Eq)] +enum CliFormat { + Raw, + Lime, + #[value(rename_all = "snake_case")] + LimeCompressed, +} + +pub fn run(args: &Args) -> Result<()> { + match (args.source_format, args.format) { + (CliFormat::Lime | CliFormat::LimeCompressed, CliFormat::Raw) => { + convert_to_raw(&args.src, &args.dst) + } + (CliFormat::Lime, CliFormat::LimeCompressed) => { + convert(&args.src, &args.dst, Format::AvmlCompressed) + } + (CliFormat::LimeCompressed, CliFormat::Lime) => convert(&args.src, &args.dst, Format::Lime), + (CliFormat::Raw, CliFormat::Lime) => convert_from_raw(&args.src, &args.dst, Format::Lime), + (CliFormat::Raw, CliFormat::LimeCompressed) => { + convert_from_raw(&args.src, &args.dst, Format::AvmlCompressed) + } + (CliFormat::Lime, CliFormat::Lime) + | (CliFormat::LimeCompressed, CliFormat::LimeCompressed) + | (CliFormat::Raw, CliFormat::Raw) => Err(Error::NoConversionRequired), + } +} + fn convert(src: &Path, dst: &Path, format: Format) -> Result<()> { let src_len = metadata(src) .map_err(|source| image::Error::Io { @@ -147,61 +191,9 @@ fn convert_from_raw(src: &Path, dst: &Path, format: Format) -> Result<()> { encode_raw_image(&mut image, src_len) } -#[derive(Parser)] -/// AVML compress/decompress tool -#[command(version)] -struct Config { - /// specify output format - #[arg(long, value_enum, default_value_t = CliFormat::LimeCompressed)] - source_format: CliFormat, - - /// specify output format - #[arg(long, value_enum, default_value_t = CliFormat::Lime)] - format: CliFormat, - - /// name of the source file to read to on local system - src: PathBuf, - - /// name of the destination file to write to on local system - dst: PathBuf, -} - -#[derive(ValueEnum, Clone, Copy, PartialEq, Eq)] -enum CliFormat { - Raw, - Lime, - #[value(rename_all = "snake_case")] - LimeCompressed, -} - -fn main() -> Result<()> { - let config = Config::parse(); - - match (config.source_format, config.format) { - (CliFormat::Lime | CliFormat::LimeCompressed, CliFormat::Raw) => { - convert_to_raw(&config.src, &config.dst) - } - (CliFormat::Lime, CliFormat::LimeCompressed) => { - convert(&config.src, &config.dst, Format::AvmlCompressed) - } - (CliFormat::LimeCompressed, CliFormat::Lime) => { - convert(&config.src, &config.dst, Format::Lime) - } - (CliFormat::Raw, CliFormat::Lime) => { - convert_from_raw(&config.src, &config.dst, Format::Lime) - } - (CliFormat::Raw, CliFormat::LimeCompressed) => { - convert_from_raw(&config.src, &config.dst, Format::AvmlCompressed) - } - (CliFormat::Lime, CliFormat::Lime) - | (CliFormat::LimeCompressed, CliFormat::LimeCompressed) - | (CliFormat::Raw, CliFormat::Raw) => Err(Error::NoConversionRequired), - } -} - #[cfg(test)] mod tests { - use crate::{convert_image, convert_to_raw_image, encode_raw_image}; + use super::{convert_image, convert_to_raw_image, encode_raw_image}; use avml::{Format, Result, image}; use rand::{Rng as _, SeedableRng as _, rngs::SmallRng}; use std::io::Cursor; diff --git a/src/bin/avml/main.rs b/src/bin/avml/main.rs new file mode 100644 index 00000000..b685a8ac --- /dev/null +++ b/src/bin/avml/main.rs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use avml::Result; +use clap::{Parser, Subcommand}; + +// `acquire` and `stream` both depend on Linux kernel interfaces +// (/proc/iomem, /proc/kcore, /dev/crash, /dev/mem). They're absent +// on non-Linux targets; on macOS / BSD / Windows the binary ships +// only `convert` and `upload` (whichever features the user enabled). +#[cfg(target_os = "linux")] +mod acquire; +#[cfg(feature = "convert")] +mod convert; +#[cfg(all(feature = "stream", target_os = "linux"))] +mod stream; +#[cfg(feature = "upload")] +mod upload; + +/// A portable volatile memory acquisition tool for Linux. +#[derive(Parser)] +#[command(author, version, long_about = None)] +struct Cmd { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + /// Acquire a memory snapshot to a local file (and optionally upload it). + #[cfg(target_os = "linux")] + Acquire(acquire::Args), + + /// Convert between AVML and `LiME` snapshot formats and a raw memory image. + #[cfg(feature = "convert")] + Convert(convert::Args), + + /// Upload an already-acquired snapshot file to remote storage. + #[cfg(feature = "upload")] + #[command(subcommand)] + Upload(upload::Commands), + + /// Stream a memory snapshot directly to remote storage, without + /// writing it to a local file. + #[cfg(all(feature = "stream", target_os = "linux"))] + #[command(subcommand)] + Stream(stream::Commands), +} + +#[cfg(not(any(feature = "stream", feature = "upload")))] +fn main() -> Result<()> { + let cmd = Cmd::parse(); + match cmd.command { + #[cfg(target_os = "linux")] + Commands::Acquire(args) => acquire::run(&args), + #[cfg(feature = "convert")] + Commands::Convert(args) => convert::run(&args), + } +} + +#[cfg(any(feature = "stream", feature = "upload"))] +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<()> { + let cmd = Cmd::parse(); + match cmd.command { + #[cfg(target_os = "linux")] + Commands::Acquire(args) => { + acquire::run(&args)?; + #[cfg(feature = "upload")] + acquire::upload_after_acquire(&args).await?; + Ok(()) + } + #[cfg(feature = "convert")] + Commands::Convert(args) => convert::run(&args), + #[cfg(feature = "upload")] + Commands::Upload(sub) => upload::run(sub).await, + #[cfg(all(feature = "stream", target_os = "linux"))] + Commands::Stream(sub) => stream::run(sub).await, + } +} diff --git a/src/bin/avml/stream.rs b/src/bin/avml/stream.rs new file mode 100644 index 00000000..a90d341b --- /dev/null +++ b/src/bin/avml/stream.rs @@ -0,0 +1,211 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use avml::{BLOB_MAX_BLOCKS, BlobError, BlockBlobStream, Format, Result, Snapshot, Source, iomem}; +use azure_storage_blob::BlobClient; +use clap::{Parser, Subcommand}; +use core::{ + num::{NonZeroU64, NonZeroUsize}, + ops::Range, +}; +use std::path::PathBuf; +use tokio_util::io::SyncIoBridge; +use url::Url; + +#[derive(Subcommand)] +pub enum Commands { + /// Stream to Azure Block Blob Storage via `stage_block` + `commit_block_list`. + Blob(BlobArgs), + + /// Stream to a remote TCP listener (e.g. `nc -l PORT > snapshot.lime`). + /// + /// The destination is opened with a single `connect`; on connection + /// failure mid-stream the snapshot aborts without retry. + Tcp(TcpArgs), +} + +#[derive(Parser)] +pub struct BlobArgs { + /// compress via snappy + #[arg(long)] + compress: bool, + + /// specify input source. If unset, the source is probed once at + /// start (kcore, then /dev/crash, then /dev/mem); the choice cannot + /// be changed once any bytes have been written. + #[arg(long, value_enum)] + source: Option, + + /// SAS URL identifying the destination Block Blob. + sas_url: Url, + + /// minimum block size in MiB. The actual block size may be larger + /// if needed to keep the total block count below Azure's 50,000 + /// limit. + #[arg(long)] + sas_block_size: Option, + + /// maximum number of in-flight `stage_block` calls. + #[arg(long)] + sas_block_concurrency: Option, +} + +#[derive(Parser)] +pub struct TcpArgs { + /// compress via snappy + #[arg(long)] + compress: bool, + + /// specify input source. If unset, the source is probed once at + /// start (kcore, then /dev/crash, then /dev/mem); the choice cannot + /// be changed once any bytes have been written. + #[arg(long, value_enum)] + source: Option, + + /// destination TCP listener as host:port. Hostnames are resolved. + addr: String, +} + +pub async fn run(cmd: Commands) -> Result<()> { + match cmd { + Commands::Blob(args) => stream_blob(args).await, + Commands::Tcp(args) => stream_tcp(args).await, + } +} + +async fn stream_blob(args: BlobArgs) -> Result<()> { + let ranges = iomem::parse()?; + let block_size = derive_block_size(&ranges, args.sas_block_size)?; + let concurrency = args + .sas_block_concurrency + .unwrap_or(avml::DEFAULT_CONCURRENCY); + + let block_client = BlobClient::new(args.sas_url, None, None) + .map_err(BlobError::from)? + .block_blob_client(); + let format = Format::from(args.compress); + + let source = match args.source { + Some(s) => s, + None => Snapshot::probe_single_source().map_err(avml::Error::from)?, + }; + + let stream = BlockBlobStream::new(block_client, block_size, concurrency); + + let (stream, result) = tokio::task::spawn_blocking( + move || -> (BlockBlobStream, core::result::Result<(), avml::Error>) { + let mut stream = stream; + // Snapshot::create_to_writer never inspects `destination`; + // any in-scope path satisfies the &Path borrow. + let dummy = PathBuf::from("/dev/null"); + let snapshot = Snapshot::new(&dummy, ranges) + .source(Some(source)) + .format(format); + let r: core::result::Result<(), avml::Error> = snapshot + .create_to_writer(stream.writer()) + .map_err(avml::Error::from) + .and_then(|()| { + stream.finish_writes().map_err(|io_err| avml::Error::Io { + context: "unable to finish blob stream", + source: io_err, + }) + }); + (stream, r) + }, + ) + .await + .map_err(|e| avml::Error::Io { + context: "spawn_blocking join failed", + source: std::io::Error::other(e.to_string()), + })?; + + match result { + Ok(()) => stream.finalize().await.map_err(avml::Error::from), + Err(e) => { + drop(stream.abort().await); + Err(e) + } + } +} + +fn derive_block_size( + ranges: &[Range], + user_floor_mib: Option, +) -> Result { + /// 5 MiB — Azure's recommended minimum for high-throughput block blobs. + const STREAM_MIN_BLOCK_SIZE: u64 = 5 * 1024 * 1024; + /// 4000 MiB — Azure's documented per-block maximum. + const STREAM_MAX_BLOCK_SIZE: u64 = 4000 * 1024 * 1024; + /// Leave headroom below Azure's 50,000-block hard cap so we don't + /// trip the limit on a slightly-over-estimate. + const BLOCK_COUNT_HEADROOM: u64 = 1000; + + let estimate = ranges + .iter() + .map(|r| r.end.saturating_sub(r.start)) + .fold(0_u64, u64::saturating_add) + .saturating_add(100 * 1024 * 1024); // overhead for headers + worst-case compression + + let target_block_count = BLOB_MAX_BLOCKS.saturating_sub(BLOCK_COUNT_HEADROOM).max(1); + let derived_min = estimate.div_ceil(target_block_count); + + let user_floor_bytes = + user_floor_mib.map_or(0, |mib| mib.get().saturating_mul(1024).saturating_mul(1024)); + + let block_size = derived_min + .max(STREAM_MIN_BLOCK_SIZE) + .max(user_floor_bytes) + .min(STREAM_MAX_BLOCK_SIZE); + + if estimate > STREAM_MAX_BLOCK_SIZE.saturating_mul(target_block_count) { + return Err(avml::Error::Blob(BlobError::TooLarge)); + } + + NonZeroUsize::new(usize::try_from(block_size).map_err(|_| avml::Error::Io { + context: "block size doesn't fit in usize", + source: std::io::Error::other("block size overflow"), + })?) + .ok_or_else(|| avml::Error::Io { + context: "block size derivation produced zero", + source: std::io::Error::other("derived zero block size"), + }) +} + +async fn stream_tcp(args: TcpArgs) -> Result<()> { + let ranges = iomem::parse()?; + let format = Format::from(args.compress); + let source = match args.source { + Some(s) => s, + None => Snapshot::probe_single_source().map_err(avml::Error::from)?, + }; + + let socket = tokio::net::TcpStream::connect(&args.addr) + .await + .map_err(|io_err| avml::Error::Io { + context: "unable to connect to TCP destination", + source: io_err, + })?; + let mut bridge = SyncIoBridge::new(socket); + + tokio::task::spawn_blocking(move || -> Result<()> { + // Snapshot::create_to_writer never inspects `destination`; + // any in-scope path satisfies the &Path borrow. + let dummy = PathBuf::from("/dev/null"); + let snapshot = Snapshot::new(&dummy, ranges) + .source(Some(source)) + .format(format); + snapshot + .create_to_writer(&mut bridge) + .map_err(avml::Error::from)?; + bridge.shutdown().map_err(|io_err| avml::Error::Io { + context: "unable to finish TCP stream", + source: io_err, + })?; + Ok(()) + }) + .await + .map_err(|e| avml::Error::Io { + context: "spawn_blocking join failed", + source: std::io::Error::other(e.to_string()), + })? +} diff --git a/src/bin/avml-upload.rs b/src/bin/avml/upload.rs similarity index 75% rename from src/bin/avml-upload.rs rename to src/bin/avml/upload.rs index 62a9edbd..36529569 100644 --- a/src/bin/avml-upload.rs +++ b/src/bin/avml/upload.rs @@ -2,51 +2,40 @@ // Licensed under the MIT License. use avml::{BlobUploader, Result, put}; -use clap::{Parser, Subcommand}; +use clap::Subcommand; use core::num::{NonZeroU64, NonZeroUsize}; use std::path::PathBuf; use url::Url; -#[derive(Parser)] -#[command(version)] -/// AVML upload tool -struct Cmd { - #[command(subcommand)] - command: Commands, -} - #[derive(Subcommand)] -enum Commands { +pub enum Commands { + /// Upload a local file via HTTP PUT. Put { /// name of the file to upload on the local system filename: PathBuf, - /// url to upload via HTTP PUT url: Url, }, - UploadBlob { + + /// Upload a local file to Azure Block Blob Storage. + Blob { /// name of the file to upload on the local system filename: PathBuf, - - /// url to upload via Azure Blob Storage + /// SAS URL identifying the destination Block Blob url: Url, - /// specify blob upload concurrency; must be greater than 0 #[arg(long)] sas_block_concurrency: Option, - /// specify maximum block size in MiB; must be greater than 0 #[arg(long)] sas_block_size: Option, }, } -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<()> { - let cmd = Cmd::parse(); - match cmd.command { +pub async fn run(cmd: Commands) -> Result<()> { + match cmd { Commands::Put { filename, url } => put(&filename, &url).await?, - Commands::UploadBlob { + Commands::Blob { filename, url, sas_block_size, diff --git a/src/image.rs b/src/image.rs index da42a3f9..245daef1 100644 --- a/src/image.rs +++ b/src/image.rs @@ -312,25 +312,7 @@ impl Image { src_filename: &Path, dst_filename: &Path, ) -> Result> { - let src_filename = canonicalize(src_filename).map_err(|source| Error::Io { - context: "unable to canonicalize path", - source, - })?; - let align_src = [ - Path::new("/dev/crash"), - Path::new("/dev/mem"), - Path::new("/proc/kcore"), - ] - .contains(&src_filename.as_path()); - - let src = OpenOptions::new() - .read(true) - .open(&src_filename) - .map_err(|source| Error::Io { - context: "unable to open memory source", - source, - })?; - + let (src, align_src) = open_src(src_filename)?; let dst = Self::open_dst(dst_filename)?; Ok(Image:: { @@ -341,6 +323,27 @@ impl Image { }) } + /// Open `src_filename` for reading and use `dst` as the destination + /// writer. Suitable for streaming the image to a non-file sink such as + /// `BlockBlobStream`. + /// + /// # Errors + /// Returns an error if the source file cannot be opened or canonicalized. + pub fn with_dst( + format: Format, + src_filename: &Path, + dst: W2, + ) -> Result> { + let (src, align_src) = open_src(src_filename)?; + + Ok(Image:: { + format, + align_src, + src, + dst, + }) + } + /// Writes multiple memory blocks to the destination file. /// /// # Errors @@ -495,6 +498,29 @@ fn range_usize(value: Range) -> Result { Ok(usize::try_from(value.end.saturating_sub(value.start))?) } +fn open_src(src_filename: &Path) -> Result<(File, bool)> { + let src_filename = canonicalize(src_filename).map_err(|source| Error::Io { + context: "unable to canonicalize path", + source, + })?; + let align_src = [ + Path::new("/dev/crash"), + Path::new("/dev/mem"), + Path::new("/proc/kcore"), + ] + .contains(&src_filename.as_path()); + + let src = OpenOptions::new() + .read(true) + .open(&src_filename) + .map_err(|source| Error::Io { + context: "unable to open memory source", + source, + })?; + + Ok((src, align_src)) +} + #[cfg(test)] mod tests { use super::{Format, Header, Image}; diff --git a/src/lib.rs b/src/lib.rs index 48ce70b7..1a0a2cac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,11 @@ mod snapshot; mod upload; #[cfg(feature = "blobstore")] -pub use crate::upload::blobstore::{BlobUploader, DEFAULT_CONCURRENCY}; +pub use crate::upload::blobstore::{BlobUploader, DEFAULT_CONCURRENCY, Error as BlobError}; #[cfg(feature = "put")] pub use crate::upload::http::put; +#[cfg(feature = "blobstore")] +pub use crate::upload::stream::{BLOB_MAX_BLOCKS, BlockBlobStream}; pub use crate::{ errors::Error, image::Format, diff --git a/src/snapshot.rs b/src/snapshot.rs index d5ff6f4b..4e6c16f5 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -277,18 +277,8 @@ impl<'a> Snapshot<'a> { if let Some(ref src) = self.source { self.create_source(src)?; } else if self.destination == Path::new("/dev/stdout") { - // If we're writing to stdout, we can't start over if reading from a - // source fails. As such, we need to do more work to pick a source - // rather than just trying all available options. - if is_kcore_ok() { - self.create_source(&Source::ProcKcore)?; - } else if can_open(Path::new("/dev/crash")) { - self.create_source(&Source::DevCrash)?; - } else if can_open(Path::new("/dev/mem")) { - self.create_source(&Source::DevMem)?; - } else { - return Err(Error::NoSourceAvailable); - } + let src = Self::probe_single_source()?; + self.create_source(&src)?; } else { let crash = match self.create_source(&Source::DevCrash) { Ok(()) => return Ok(()), @@ -316,6 +306,61 @@ impl<'a> Snapshot<'a> { Ok(()) } + /// Probe for an available source without trying multiple. Used when the + /// destination cannot be rewound (`/dev/stdout`, streaming blob upload). + /// + /// Preference order matches the historical `/dev/stdout` branch: kcore, + /// then `/dev/crash`, then `/dev/mem`. + /// + /// # Errors + /// Returns `NoSourceAvailable` if none of the three probes succeed. + pub fn probe_single_source() -> Result { + if is_kcore_ok() { + Ok(Source::ProcKcore) + } else if can_open(Path::new("/dev/crash")) { + Ok(Source::DevCrash) + } else if can_open(Path::new("/dev/mem")) { + Ok(Source::DevMem) + } else { + Err(Error::NoSourceAvailable) + } + } + + /// Stream a memory snapshot to an arbitrary writer. + /// + /// Unlike [`Self::create`], this does **not** auto-retry across + /// sources. The destination is not assumed to be rewindable, so once + /// any bytes are written we cannot fall back to a different source. + /// The caller must either supply a [`Source`] via [`Self::source`] or + /// rely on [`Self::probe_single_source`] to pick one up front. + /// + /// `max_disk_usage` and `max_disk_usage_percentage` are ignored: with + /// no local disk involvement the limits don't apply. The caller is + /// expected to enforce any blob-side size limits separately. + /// + /// # Errors + /// Returns an error if: + /// - No source is available + /// - There is a failure reading from the source + /// - Writing to `dst` fails + pub fn create_to_writer(&self, dst: W) -> Result<()> { + let source = match self.source { + Some(ref s) => s.clone(), + None => Self::probe_single_source()?, + }; + + match source { + Source::ProcKcore => self.kcore_to_writer(dst), + Source::DevCrash => self.phys_to_writer(Path::new("/dev/crash"), dst), + Source::DevMem => self.phys_to_writer(Path::new("/dev/mem"), dst), + Source::Raw(ref s) => self.phys_to_writer(s, dst), + } + .map_err(|e| Error::UnableToCreateSnapshotFromSource { + src: source, + source: Box::new(e), + }) + } + // given a set of ranges from iomem and a set of Blocks derived from the // pseudo-elf phys section headers, derive a set of ranges that can be used // to create a snapshot. @@ -396,7 +441,22 @@ impl<'a> Snapshot<'a> { let mut image = Image::::new(self.format, Path::new("/proc/kcore"), self.destination)?; self.check_disk_usage(&image)?; + Self::write_kcore_blocks(&mut image, &self.memory_ranges) + } + fn kcore_to_writer(&self, dst: W) -> Result<()> { + if !is_kcore_ok() { + return Err(Error::LockedDownKcore); + } + + let mut image = Image::::with_dst(self.format, Path::new("/proc/kcore"), dst)?; + Self::write_kcore_blocks(&mut image, &self.memory_ranges) + } + + fn write_kcore_blocks( + image: &mut Image, + memory_ranges: &[Range], + ) -> Result<()> { let file = elf::ElfStream::::open_stream(&mut image.src)?; let physical_ranges = Self::physical_ranges_from_segments(file.segments()); @@ -405,11 +465,11 @@ impl<'a> Snapshot<'a> { "no usable PT_LOAD segments in /proc/kcore", )); } - if self.memory_ranges.is_empty() { + if memory_ranges.is_empty() { return Err(Error::KcoreParse("no initial memory range")); } - let blocks = Self::find_kcore_blocks(&self.memory_ranges, &physical_ranges); + let blocks = Self::find_kcore_blocks(memory_ranges, &physical_ranges); image.write_blocks(&blocks)?; Ok(()) } @@ -456,9 +516,23 @@ impl<'a> Snapshot<'a> { } fn phys(&self, mem: &Path) -> Result<()> { + let blocks = Self::phys_blocks(mem, &self.memory_ranges); + let mut image = Image::::new(self.format, mem, self.destination)?; + self.check_disk_usage(&image)?; + image.write_blocks(&blocks)?; + Ok(()) + } + + fn phys_to_writer(&self, mem: &Path, dst: W) -> Result<()> { + let blocks = Self::phys_blocks(mem, &self.memory_ranges); + let mut image = Image::::with_dst(self.format, mem, dst)?; + image.write_blocks(&blocks)?; + Ok(()) + } + + fn phys_blocks(mem: &Path, memory_ranges: &[Range]) -> Vec { let is_crash = mem == Path::new("/dev/crash"); - let blocks = self - .memory_ranges + memory_ranges .iter() .map(|x| Block { offset: x.start, @@ -468,14 +542,7 @@ impl<'a> Snapshot<'a> { x.start..x.end }, }) - .collect::>(); - - let mut image = Image::::new(self.format, mem, self.destination)?; - self.check_disk_usage(&image)?; - - image.write_blocks(&blocks)?; - - Ok(()) + .collect::>() } } diff --git a/src/upload/mod.rs b/src/upload/mod.rs index b6d8a384..93c66915 100644 --- a/src/upload/mod.rs +++ b/src/upload/mod.rs @@ -4,6 +4,9 @@ #[cfg(feature = "blobstore")] pub mod blobstore; +#[cfg(feature = "blobstore")] +pub mod stream; + #[cfg(feature = "put")] pub mod http; diff --git a/src/upload/stream.rs b/src/upload/stream.rs new file mode 100644 index 00000000..ea0937fd --- /dev/null +++ b/src/upload/stream.rs @@ -0,0 +1,774 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Stream a memory snapshot directly to an Azure Block Blob. +//! +//! Bytes are buffered into fixed-size blocks. Each full block is staged via +//! [`BlockBlobClient::stage_block`]. Concurrency across staged blocks is +//! bounded by a [`Semaphore`]. After the snapshot writer is finished, the +//! caller invokes [`BlockBlobStream::finalize`] which awaits any in-flight +//! stage operations and commits the block list. On failure, the caller +//! invokes [`BlockBlobStream::abort`], which awaits in-flight tasks but +//! does not commit; uncommitted blocks are discarded by Azure on its own +//! timeline. + +use crate::upload::blobstore::Error; +use async_trait::async_trait; +use azure_core::{ + Bytes, + http::{NoFormat, RequestContent, XmlFormat}, +}; +use azure_storage_blob::{ + BlockBlobClient, + models::{ + BlockBlobClientCommitBlockListOptions, BlockBlobClientStageBlockOptions, BlockLookupList, + }, +}; +use core::num::NonZeroUsize; +use core::{ + pin::Pin, + task::{Context, Poll}, +}; +use std::{ + io::{Result as IoResult, Write}, + sync::{Arc, Mutex}, +}; +use tokio::{ + io::AsyncWrite, + runtime::Handle, + sync::{Semaphore, mpsc}, + task::JoinHandle, +}; +use tokio_util::io::SyncIoBridge; + +type Result = core::result::Result; + +/// Block IDs as a fixed 8-byte big-endian representation of a u64 counter. +/// Azure requires all block IDs within a single commit to have identical +/// byte length; using the raw `to_be_bytes()` representation guarantees +/// that and produces an ordering that matches the staging order when +/// compared lexicographically. +fn block_id(index: u64) -> Vec { + index.to_be_bytes().to_vec() +} + +/// Abstraction over the two `BlockBlobClient` methods this module uses, +/// so tests can substitute an in-memory fake without standing up Azure. +#[expect( + clippy::redundant_pub_crate, + reason = "appears in the signature of pub(crate) BlockBlobStream::with_stager" +)] +#[async_trait] +pub(crate) trait BlockStager: Send + Sync + 'static { + async fn stage_block(&self, block_id: Vec, body: Bytes) -> Result<()>; + async fn commit_block_list(&self, block_ids: Vec>) -> Result<()>; +} + +/// Live `BlockStager` backed by `azure_storage_blob`. +struct SdkStager { + client: Arc, +} + +#[async_trait] +impl BlockStager for SdkStager { + async fn stage_block(&self, block_id: Vec, body: Bytes) -> Result<()> { + let len = u64::try_from(body.len())?; + let content: RequestContent = body.into(); + self.client + .stage_block( + &block_id, + len, + content, + Option::>::None, + ) + .await?; + Ok(()) + } + + async fn commit_block_list(&self, block_ids: Vec>) -> Result<()> { + let list = BlockLookupList { + latest: Some(block_ids), + ..Default::default() + }; + let content: RequestContent = list.try_into()?; + self.client + .commit_block_list( + content, + Option::>::None, + ) + .await?; + Ok(()) + } +} + +/// Messages from the writer to the uploader task. +enum UploaderMsg { + Stage { index: u64, data: Bytes }, +} + +/// Final result returned by the uploader task once the writer side closes +/// the channel. +struct UploaderResult { + /// Successfully staged block indices, in arbitrary order. + completed: Vec, + /// First error observed across all `stage_block` calls. + first_error: Option, +} + +type ReservationFuture = Pin< + Box< + dyn Future< + Output = core::result::Result< + mpsc::OwnedPermit, + mpsc::error::SendError<()>, + >, + > + Send, + >, +>; + +/// Sync writer side: implements [`AsyncWrite`] by buffering up to +/// `block_size` bytes, then handing the buffer off to the uploader task +/// via a bounded mpsc. +/// +/// `poll_write` returns `Pending` when the channel is full, which gives +/// the producer real backpressure: the bound is `concurrency`, matching +/// the semaphore inside the uploader. +struct BlockBlobAsyncWriter { + sender: Option>, + buf: Vec, + block_size: usize, + next_index: u64, + /// `Some` once the uploader has observed (or the writer has observed + /// via a closed channel) that the receiver is gone. After that point + /// `poll_write` returns the captured error. + error_slot: Arc>>, + /// Pending reservation across `poll_write` invocations. + pending_reservation: Option, +} + +impl BlockBlobAsyncWriter { + fn new( + sender: mpsc::Sender, + block_size: NonZeroUsize, + error_slot: Arc>>, + ) -> Self { + Self { + sender: Some(sender), + buf: Vec::with_capacity(block_size.get()), + block_size: block_size.get(), + next_index: 0, + error_slot, + pending_reservation: None, + } + } + + fn take_first_error(&self) -> Option { + if let Ok(mut slot) = self.error_slot.lock() { + slot.take() + } else { + None + } + } + + /// Try to dispatch the current buffer if it is full. Returns + /// `Poll::Pending` if a permit can't be acquired without waiting, + /// or `Poll::Ready(Ok(()))` if the buffer was either not full or + /// successfully dispatched. + fn try_dispatch(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.buf.len() < self.block_size { + return Poll::Ready(Ok(())); + } + + if self.pending_reservation.is_none() { + let Some(sender) = self.sender.as_ref() else { + return Poll::Ready(Err(std::io::Error::other( + "blob writer was already shut down", + ))); + }; + let sender = sender.clone(); + self.pending_reservation = Some(Box::pin(sender.reserve_owned())); + } + + let mut reservation = self + .pending_reservation + .take() + .ok_or_else(|| std::io::Error::other("missing reservation slot"))?; + match reservation.as_mut().poll(cx) { + Poll::Pending => { + self.pending_reservation = Some(reservation); + Poll::Pending + } + Poll::Ready(Err(_send)) => { + // Receiver dropped -> uploader exited (probably due to error). + self.sender = None; + let err = self + .take_first_error() + .unwrap_or_else(|| Error::Io(std::io::Error::other("uploader exited early"))); + Poll::Ready(Err(std::io::Error::other(err.to_string()))) + } + Poll::Ready(Ok(permit)) => { + let index = self.next_index; + self.next_index = self.next_index.saturating_add(1); + let block_size = self.block_size; + let data = core::mem::replace(&mut self.buf, Vec::with_capacity(block_size)); + permit.send(UploaderMsg::Stage { + index, + data: Bytes::from(data), + }); + Poll::Ready(Ok(())) + } + } + } +} + +impl AsyncWrite for BlockBlobAsyncWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(err) = self.take_first_error() { + return Poll::Ready(Err(std::io::Error::other(err.to_string()))); + } + + if self.buf.len() >= self.block_size { + match self.try_dispatch(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(())) => {} + } + } + + let take = self + .block_size + .saturating_sub(self.buf.len()) + .min(buf.len()); + self.buf.extend_from_slice(buf.get(..take).unwrap_or(&[])); + Poll::Ready(Ok(take)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // Partial flush would emit a short block in the middle of the + // blob, which Azure forbids. Flush is a no-op; the trailing + // partial buffer is staged only by poll_shutdown. + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(err) = self.take_first_error() { + return Poll::Ready(Err(std::io::Error::other(err.to_string()))); + } + + // Stage the trailing partial buffer, if any, as the final block. + if !self.buf.is_empty() { + if self.pending_reservation.is_none() { + let Some(sender) = self.sender.as_ref() else { + return Poll::Ready(Ok(())); + }; + let sender = sender.clone(); + self.pending_reservation = Some(Box::pin(sender.reserve_owned())); + } + let mut reservation = self + .pending_reservation + .take() + .ok_or_else(|| std::io::Error::other("missing reservation slot"))?; + match reservation.as_mut().poll(cx) { + Poll::Pending => { + self.pending_reservation = Some(reservation); + return Poll::Pending; + } + Poll::Ready(Err(_)) => { + self.sender = None; + let err = self.take_first_error().unwrap_or_else(|| { + Error::Io(std::io::Error::other("uploader exited early")) + }); + return Poll::Ready(Err(std::io::Error::other(err.to_string()))); + } + Poll::Ready(Ok(permit)) => { + let index = self.next_index; + self.next_index = self.next_index.saturating_add(1); + let block_size = self.block_size; + let data = core::mem::replace(&mut self.buf, Vec::with_capacity(block_size)); + permit.send(UploaderMsg::Stage { + index, + data: Bytes::from(data), + }); + } + } + } + + // Drop the sender so the uploader task can finish. + self.sender = None; + Poll::Ready(Ok(())) + } +} + +/// Public handle for streaming a memory snapshot into a Block Blob. +/// +/// Construct with [`BlockBlobStream::new`]. Drive the sync writer +/// (returned by [`Self::writer`]) from a blocking context — typically +/// inside [`tokio::task::spawn_blocking`]. After writing finishes, +/// call [`Self::finish_writes`] from the same blocking context to flush +/// the trailing partial block; then call [`Self::finalize`] (or +/// [`Self::abort`]) from async context to await uploads and commit (or +/// discard) the block list. +/// +/// # Runtime requirements +/// +/// `new` and `finalize`/`abort` must be invoked from inside a tokio +/// runtime. The sync writer must be invoked from a thread that is *not* +/// a runtime worker (i.e., from `spawn_blocking`), otherwise the +/// internal `block_on` deadlocks the current-thread runtime that the +/// avml binary uses. +pub struct BlockBlobStream { + bridge: SyncIoBridge, + uploader: Option>, + stager: Arc, + /// Maximum number of blocks Azure will accept in a single commit. + /// Enforced when the uploader assigns indices. + max_blocks: u64, +} + +/// Azure's per-blob block count limit. Public for callers (e.g. the +/// binary) that want to derive a safe block size up front. +pub const BLOB_MAX_BLOCKS: u64 = 50_000; + +impl BlockBlobStream { + /// Construct a streaming uploader against a live block blob. + #[must_use] + pub fn new( + client: BlockBlobClient, + block_size: NonZeroUsize, + concurrency: NonZeroUsize, + ) -> Self { + Self::with_stager( + Arc::new(SdkStager { + client: Arc::new(client), + }), + block_size, + concurrency, + ) + } + + pub(crate) fn with_stager( + stager: Arc, + block_size: NonZeroUsize, + concurrency: NonZeroUsize, + ) -> Self { + let handle = Handle::current(); + let error_slot = Arc::new(Mutex::new(None)); + let (tx, rx) = mpsc::channel::(concurrency.get()); + + let uploader = handle.spawn(run_uploader( + stager.clone(), + rx, + Arc::new(Semaphore::new(concurrency.get())), + error_slot.clone(), + )); + + let writer = BlockBlobAsyncWriter::new(tx, block_size, error_slot); + let bridge = SyncIoBridge::new_with_handle(writer, handle); + + Self { + bridge, + uploader: Some(uploader), + stager, + max_blocks: BLOB_MAX_BLOCKS, + } + } + + /// Returns the sync writer to feed into the snapshot pipeline. + /// Must be driven from a blocking thread. + pub fn writer(&mut self) -> &mut dyn Write { + &mut self.bridge + } + + /// Flush any partial trailing block. Must be called from the same + /// blocking thread that drove the writer. + /// + /// # Errors + /// Returns an error if the underlying [`AsyncWrite::poll_shutdown`] + /// returns one (e.g., the uploader task already exited due to an + /// upload failure). + pub fn finish_writes(&mut self) -> IoResult<()> { + self.bridge.shutdown() + } + + /// Await all in-flight `stage_block` calls and commit the block list. + /// Consumes `self` because no further writes are valid after commit. + /// + /// # Errors + /// Returns any captured `stage_block` error or any error from + /// `commit_block_list`. + pub async fn finalize(mut self) -> Result<()> { + let result = self.await_uploader().await?; + if let Some(err) = result.first_error { + return Err(err); + } + let mut indices = result.completed; + let staged_count = u64::try_from(indices.len()).unwrap_or(u64::MAX); + if staged_count > self.max_blocks { + return Err(Error::TooLarge); + } + indices.sort_unstable(); + let block_ids: Vec> = indices.into_iter().map(block_id).collect(); + self.stager.commit_block_list(block_ids).await + } + + /// Await all in-flight `stage_block` calls without committing. Staged + /// but uncommitted blocks are discarded by Azure on its own timeline. + /// + /// # Errors + /// Best-effort; returns the first error seen, but does not call + /// `commit_block_list`. + pub async fn abort(mut self) -> Result<()> { + drop(self.await_uploader().await); + Ok(()) + } + + async fn await_uploader(&mut self) -> Result { + let Some(uploader) = self.uploader.take() else { + return Ok(UploaderResult { + completed: Vec::new(), + first_error: None, + }); + }; + uploader + .await + .map_err(|e| Error::Io(std::io::Error::other(e.to_string()))) + } +} + +async fn run_uploader( + stager: Arc, + mut rx: mpsc::Receiver, + semaphore: Arc, + error_slot: Arc>>, +) -> UploaderResult { + let mut in_flight: Vec>> = Vec::new(); + + while let Some(msg) = rx.recv().await { + match msg { + UploaderMsg::Stage { index, data } => { + // Acquire a permit (bounded in-flight). Held by the worker + // task until stage_block completes. + let Ok(permit) = semaphore.clone().acquire_owned().await else { + break; + }; + let stager = stager.clone(); + let id = block_id(index); + let worker = tokio::spawn(async move { + let _permit = permit; + stager + .stage_block(id, data) + .await + .map(|()| index) + .map_err(|e| (index, e)) + }); + in_flight.push(worker); + } + } + } + + // Channel closed; await all in-flight stages. + let mut completed = Vec::with_capacity(in_flight.len()); + for handle in in_flight { + match handle.await { + Ok(Ok(index)) => completed.push(index), + Ok(Err((_index, err))) => { + if let Ok(mut slot) = error_slot.lock() + && slot.is_none() + { + *slot = Some(err); + } + } + Err(join_err) => { + if let Ok(mut slot) = error_slot.lock() + && slot.is_none() + { + *slot = Some(Error::Io(std::io::Error::other(join_err.to_string()))); + } + } + } + } + + let first_error = error_slot.lock().ok().and_then(|mut s| s.take()); + + UploaderResult { + completed, + first_error, + } +} + +#[cfg(test)] +mod tests { + #![expect( + clippy::expect_used, + clippy::indexing_slicing, + clippy::similar_names, + reason = "tests assert on pre-known shapes and value counts" + )] + + use super::*; + use core::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Mutex as StdMutex; + + /// In-memory `BlockStager` used by tests. Records every staged block + /// and every commit call. Optionally fails a specific block index. + struct FakeStager { + staged: StdMutex, Bytes)>>, + commits: StdMutex>>>, + fail_index: Option, + stage_call_count: AtomicUsize, + max_concurrent_stages: AtomicUsize, + current_stages: AtomicUsize, + } + + impl FakeStager { + fn new() -> Self { + Self { + staged: StdMutex::new(Vec::new()), + commits: StdMutex::new(Vec::new()), + fail_index: None, + stage_call_count: AtomicUsize::new(0), + max_concurrent_stages: AtomicUsize::new(0), + current_stages: AtomicUsize::new(0), + } + } + + fn failing(index: u64) -> Self { + Self { + fail_index: Some(index), + ..Self::new() + } + } + + fn locked_staged(&self) -> Vec<(Vec, Bytes)> { + self.staged + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } + + fn locked_commits(&self) -> Vec>> { + self.commits + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } + } + + #[async_trait] + impl BlockStager for FakeStager { + async fn stage_block(&self, block_id: Vec, body: Bytes) -> Result<()> { + self.stage_call_count.fetch_add(1, Ordering::SeqCst); + let now = self + .current_stages + .fetch_add(1, Ordering::SeqCst) + .saturating_add(1); + self.max_concurrent_stages.fetch_max(now, Ordering::SeqCst); + + // Give other in-flight tasks a chance to overlap so the + // concurrency test can observe parallelism (or its absence). + tokio::task::yield_now().await; + tokio::task::yield_now().await; + + let should_fail = self.fail_index.is_some_and(|target| { + if block_id.len() == 8 { + let mut buf = [0_u8; 8]; + buf.copy_from_slice(&block_id); + u64::from_be_bytes(buf) == target + } else { + false + } + }); + + self.current_stages.fetch_sub(1, Ordering::SeqCst); + if should_fail { + return Err(Error::Io(std::io::Error::other("simulated failure"))); + } + + self.staged + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push((block_id, body)); + Ok(()) + } + + async fn commit_block_list(&self, block_ids: Vec>) -> Result<()> { + self.commits + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(block_ids); + Ok(()) + } + } + + fn nz(n: usize) -> NonZeroUsize { + NonZeroUsize::new(n).expect("test constant non-zero") + } + + fn build_stream( + stager: Arc, + block_size: usize, + concurrency: usize, + ) -> BlockBlobStream { + BlockBlobStream::with_stager(stager, nz(block_size), nz(concurrency)) + } + + async fn run_write(stream: BlockBlobStream, write: F) -> (BlockBlobStream, IoResult<()>) + where + F: FnOnce(&mut dyn Write) -> IoResult<()> + Send + 'static, + { + // SyncIoBridge requires us to be off the runtime thread; spawn_blocking + // models the real usage from the binary. + tokio::task::spawn_blocking(move || { + let mut stream = stream; + let result = write(stream.writer()); + let shutdown = stream.finish_writes(); + let combined = result.and(shutdown); + (stream, combined) + }) + .await + .expect("spawn_blocking join") + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn finalize_with_zero_writes_commits_empty_block_list() { + let stager = Arc::new(FakeStager::new()); + let stream = build_stream(stager.clone(), 8, 2); + + let (stream, result) = run_write(stream, |_w| Ok(())).await; + result.expect("writer shutdown succeeds"); + stream.finalize().await.expect("finalize succeeds"); + + let commits = stager.locked_commits(); + assert_eq!(commits.len(), 1, "exactly one commit"); + assert!( + commits[0].is_empty(), + "empty snapshot commits empty block list" + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn rotation_at_exact_block_size_emits_uniform_ids() { + let stager = Arc::new(FakeStager::new()); + let stream = build_stream(stager.clone(), 4, 3); + + let payload: Vec = (0..12).collect(); + let (stream, result) = run_write(stream, move |w| w.write_all(&payload)).await; + result.expect("write + shutdown"); + stream.finalize().await.expect("finalize"); + + let staged = stager.locked_staged(); + assert_eq!(staged.len(), 3, "three full blocks staged"); + for entry in &staged { + assert_eq!(entry.0.len(), 8, "block ids uniform width"); + } + + let commits = stager.locked_commits(); + assert_eq!(commits.len(), 1); + let ids = &commits[0]; + let mut sorted = ids.clone(); + sorted.sort(); + assert_eq!(ids, &sorted, "committed ids are sorted ascending"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn trailing_partial_block_is_staged_on_shutdown() { + let stager = Arc::new(FakeStager::new()); + let stream = build_stream(stager.clone(), 4, 3); + + let payload: Vec = (0..6).collect(); // 4 + 2 + let (stream, result) = run_write(stream, move |w| w.write_all(&payload)).await; + result.expect("write + shutdown"); + stream.finalize().await.expect("finalize"); + + let mut staged = stager.locked_staged(); + // Stages may complete out of order with concurrency>1; sort by id. + staged.sort_by(|a, b| a.0.cmp(&b.0)); + assert_eq!(staged.len(), 2); + assert_eq!(staged[0].1.len(), 4); + assert_eq!(staged[1].1.len(), 2); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn stage_failure_surfaces_and_skips_commit() { + let stager = Arc::new(FakeStager::failing(1)); + let stream = build_stream(stager.clone(), 4, 2); + + let payload: Vec = (0..12).collect(); + let (stream, _result) = run_write(stream, move |w| { + // The write may or may not succeed depending on timing; + // either way the error surfaces in finalize. + drop(w.write_all(&payload)); + Ok(()) + }) + .await; + let err = stream + .finalize() + .await + .expect_err("finalize must report stage failure"); + assert!(matches!(err, Error::Io(_)), "got: {err:?}"); + + let commits = stager.locked_commits(); + assert!(commits.is_empty(), "no commit after stage failure"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn abort_does_not_commit() { + let stager = Arc::new(FakeStager::new()); + let stream = build_stream(stager.clone(), 4, 2); + + let payload: Vec = (0..8).collect(); + let (stream, result) = run_write(stream, move |w| w.write_all(&payload)).await; + result.expect("write + shutdown"); + stream.abort().await.expect("abort succeeds"); + + let commits = stager.locked_commits(); + assert!(commits.is_empty(), "abort skips commit"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn concurrency_bound_is_respected() { + let stager = Arc::new(FakeStager::new()); + let stream = build_stream(stager.clone(), 1, 1); + + // 6 single-byte blocks; with concurrency=1, max in-flight must be 1. + let payload: Vec = (0..6).collect(); + let (stream, result) = run_write(stream, move |w| w.write_all(&payload)).await; + result.expect("write + shutdown"); + stream.finalize().await.expect("finalize"); + + let observed = stager.max_concurrent_stages.load(Ordering::SeqCst); + assert_eq!(observed, 1, "concurrency=1 caps in-flight at 1"); + assert_eq!(stager.stage_call_count.load(Ordering::SeqCst), 6); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn higher_concurrency_allows_overlap() { + let stager = Arc::new(FakeStager::new()); + let stream = build_stream(stager.clone(), 1, 4); + + // 8 single-byte blocks with concurrency=4: max in-flight may be up to 4. + let payload: Vec = (0..8).collect(); + let (stream, result) = run_write(stream, move |w| w.write_all(&payload)).await; + result.expect("write + shutdown"); + stream.finalize().await.expect("finalize"); + + let observed = stager.max_concurrent_stages.load(Ordering::SeqCst); + assert!(observed >= 2, "expected some overlap, got {observed}"); + assert!(observed <= 4, "must not exceed configured concurrency"); + } + + #[test] + fn block_id_round_trip() { + for i in [0_u64, 1, 100, u64::MAX] { + let id = block_id(i); + assert_eq!(id.len(), 8); + let mut buf = [0_u8; 8]; + buf.copy_from_slice(&id); + assert_eq!(u64::from_be_bytes(buf), i); + } + } +}