Contrib: S3Diff one-step 4x super-resolution#149
Open
jimburtoft wants to merge 3 commits into
Open
Conversation
S3Diff (ECCV 2024) performs degradation-guided 4x super-resolution in a single denoising step using SD-Turbo with dynamic LoRA modulation. A DEResNet encoder estimates degradation and produces per-layer modulation matrices injected between LoRA A/B weights. Uses torch_neuronx.trace() (no TP needed, model is ~2 GB). Validated on trn2.3xlarge, SDK 2.29: 0.544s/image, ~21x CPU speedup. LoRA components use --model-type=unet-inference to avoid NaN from --auto-cast=matmult on small einsum operations.
trace() approach is validated at 128x128->512x512 only. Higher resolutions produce NaN or degraded output due to BF16 accumulation in LoRA einsum at larger spatial dims. Reference torch.compile alternative for multi-resolution (1K/2K/4K) use cases.
Implements overlapping tile processing with Gaussian blending so that images larger than 512x512 HR (e.g., 256->1024, 512->2048) are handled without recompilation. All components remain compiled at fixed 512x512 tile size; larger images are split, processed per-tile, and blended. Benchmarks on trn2.3xlarge: - 128->512: 0.545s (single tile) - 256->1024: 4.8s (9 tiles) - 512->2048: 13.3s (25 tiles) Tests: 5/5 pass including new tiling tests.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds S3Diff (ECCV 2024) one-step 4x super-resolution to contrib. S3Diff uses SD-Turbo's UNet with degradation-guided dynamic LoRA modulation: a DEResNet encoder estimates input degradation and produces per-layer
[rank, rank]modulation matrices injected between LoRA A/B weights via einsum.torch_neuronx.trace()(model is ~2 GB, no tensor parallelism needed)Benchmark Results (128x128 -> 512x512)
Implementation Notes
--model-type=unet-inferenceinstead of--auto-cast=matmultbecause the small einsum modulation operations produce NaN under BF16 auto-casting.--auto-cast=matmultnormally.Files
src/modeling_s3diff.py— Full pipeline: DEResNet, LoRA forward, wrappers, S3DiffNeuronPipeline classsrc/generate_s3diff.py— CLI generation script with weight download supporttest/integration/test_model.py— 3 integration tests (smoke, SR size, timing)README.md— Full documentationTests
All 3 integration tests pass on trn2.3xlarge:
test_smoke_pipeline_loads— PASSEDtest_sr_produces_correct_size— PASSED (512x512 output, pixel std=19.9)test_warm_generation_time— PASSED (0.544s < 2s threshold)Weights
stabilityai/sd-turbozhangap/S3Diff