diff --git a/.gitignore b/.gitignore index 20631ef..4c441e6 100644 --- a/.gitignore +++ b/.gitignore @@ -20,9 +20,9 @@ __pycache__/ build/ results/ checkpoints/ -data/ develop-eggs/ dist/ +data/ARKitScenes/data downloads/ eggs/ .eggs/ @@ -188,3 +188,7 @@ glove *.ipynb *.ply 3rdparty + +# others +promptda/scripts/yolov8n.pt +data/ARKitScenes/data diff --git a/compare_inference.py b/compare_inference.py new file mode 100644 index 0000000..a89f89b --- /dev/null +++ b/compare_inference.py @@ -0,0 +1,210 @@ +""" +Comparison script: Baseline vs MLF inference. + +Runs both baseline and MLF models on the same dataset and compares metrics side-by-side. + +Usage: + python compare_inference.py --data_root data/ARKitScenes/data/upsampling \ + --encoder vitl \ + --mlf_checkpoint checkpoints/experiment/best.pth \ + --output_dir results/comparison +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +# Add project root to path +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from dataset.arkitscene import ARKitScenesDataset, collate_fn +from promptda.promptda import PromptDA +from promptda.utils.logger import Log +from training.metrics import compute_depth_metrics, aggregate_metrics + + +def parse_args(): + p = argparse.ArgumentParser(description="PromptDA Baseline vs MLF Comparison") + + # Data + p.add_argument("--data_root", type=str, default="data/ARKitScenes/data/upsampling") + p.add_argument("--split", type=str, default="Validation", choices=["Training", "Validation"]) + p.add_argument("--image_size", type=int, nargs=2, default=[196, 252]) + p.add_argument("--num_workers", type=int, default=4) + + # Model + p.add_argument("--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl"]) + p.add_argument("--pretrained_path", type=str, default=None, + help="Baseline pretrained checkpoint") + p.add_argument("--mlf_checkpoint", type=str, required=True, + help="MLF trained checkpoint") + + # Output + p.add_argument("--output_dir", type=str, default="results/comparison") + p.add_argument("--batch_size", type=int, default=4) + + return p.parse_args() + + +def load_checkpoint(ckpt_path, model, device): + """Load checkpoint weights.""" + if not os.path.exists(ckpt_path): + Log.warn(f"Checkpoint not found: {ckpt_path}") + return False + + Log.info(f"Loading checkpoint: {ckpt_path}") + state = torch.load(ckpt_path, map_location=device) + + if "model" in state: + state_dict = state["model"] + else: + state_dict = state + + model.load_state_dict(state_dict, strict=False) + return True + + +def run_inference(model, loader, device, model_name): + """Run inference and collect metrics.""" + Log.info(f"Running {model_name} inference...") + all_metrics = [] + + with torch.no_grad(): + for batch in tqdm(loader, desc=model_name, leave=False): + image = batch["image"].to(device) + depth_gt = batch["depth_gt"].to(device) + prompt = batch["prompt"].to(device) + boxes = [b.to(device) for b in batch["boxes"]] + + pred = model(image, prompt, boxes=boxes if boxes[0].numel() > 0 else None) + metrics = compute_depth_metrics(pred, depth_gt) + all_metrics.append(metrics) + + return aggregate_metrics(all_metrics) + + +def main(): + args = parse_args() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + Log.info(f"Device: {device}") + Log.info(f"Encoder: {args.encoder}") + Log.info(f"Split: {args.split}") + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Load dataset + Log.info(f"Loading {args.split} dataset...") + dataset = ARKitScenesDataset( + data_root=args.data_root, + split=args.split, + image_size=tuple(args.image_size), + ) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=collate_fn, + pin_memory=True, + ) + Log.info(f"Dataset size: {len(dataset)}") + + # Load baseline model + Log.info(f"Loading baseline model (no MLF)...") + baseline_model = PromptDA.from_pretrained( + pretrained_model_name_or_path=args.pretrained_path, + encoder=args.encoder, + use_mlf=False, + ).to(device).eval() + + # Load MLF model + Log.info(f"Loading MLF model...") + mlf_model = PromptDA.from_pretrained( + pretrained_model_name_or_path=args.pretrained_path, + encoder=args.encoder, + use_mlf=True, + ).to(device).eval() + + # Load MLF checkpoint + if not load_checkpoint(args.mlf_checkpoint, mlf_model, device): + Log.warn("Using random MLF weights (this will give poor results)") + + # Run inference for both models + baseline_metrics = run_inference(baseline_model, loader, device, "Baseline") + mlf_metrics = run_inference(mlf_model, loader, device, "MLF") + + # Compute improvements + improvements = {} + for key in baseline_metrics.keys(): + baseline_val = baseline_metrics[key] + mlf_val = mlf_metrics[key] + + # For metrics where lower is better (loss, error metrics) + if key in ["AbsRel", "MAE", "RMSE", "Log10", "SILog"]: + improvement = (baseline_val - mlf_val) / (baseline_val + 1e-8) * 100 + improvements[key] = improvement + # For metrics where higher is better (delta) + else: + improvement = (mlf_val - baseline_val) / (baseline_val + 1e-8) * 100 + improvements[key] = improvement + + # Print comparison + Log.info("=" * 80) + Log.info(f"Baseline vs MLF Comparison ({args.split})") + Log.info("=" * 80) + Log.info(f"{'Metric':<15} {'Baseline':<15} {'MLF':<15} {'Improvement':<15}") + Log.info("-" * 80) + + for key in sorted(baseline_metrics.keys()): + baseline_val = baseline_metrics[key] + mlf_val = mlf_metrics[key] + improvement = improvements[key] + + improvement_sign = "↓" if improvement < 0 else "↑" + log_str = f"{key:<15} {baseline_val:<15.6f} {mlf_val:<15.6f} {improvement_sign} {abs(improvement):>10.2f}%" + Log.info(log_str) + + Log.info("=" * 80) + + # Save results + results = { + "baseline": baseline_metrics, + "mlf": mlf_metrics, + "improvements": improvements, + } + + results_file = os.path.join(args.output_dir, "comparison_results.json") + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + Log.info(f"Results saved to: {results_file}") + + # Save comparison as text + txt_file = os.path.join(args.output_dir, "comparison_results.txt") + with open(txt_file, "w") as f: + f.write(f"Baseline vs MLF Comparison ({args.split})\n") + f.write("=" * 80 + "\n") + f.write(f"{'Metric':<15} {'Baseline':<15} {'MLF':<15} {'Improvement':<15}\n") + f.write("-" * 80 + "\n") + for key in sorted(baseline_metrics.keys()): + baseline_val = baseline_metrics[key] + mlf_val = mlf_metrics[key] + improvement = improvements[key] + improvement_sign = "↓" if improvement < 0 else "↑" + f.write(f"{key:<15} {baseline_val:<15.6f} {mlf_val:<15.6f} {improvement_sign} {abs(improvement):>10.2f}%\n") + f.write("=" * 80 + "\n") + Log.info(f"Comparison saved to: {txt_file}") + + return results + + +if __name__ == "__main__": + main() diff --git a/data/ARKitScenes/.gitignore b/data/ARKitScenes/.gitignore new file mode 100644 index 0000000..7ceeec5 --- /dev/null +++ b/data/ARKitScenes/.gitignore @@ -0,0 +1,6 @@ +.idea +venv +*pyc +online_prepared_data +offline_prepared_data +threedod/sample_data \ No newline at end of file diff --git a/data/ARKitScenes/CODE_OF_CONDUCT.md b/data/ARKitScenes/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..c991377 --- /dev/null +++ b/data/ARKitScenes/CODE_OF_CONDUCT.md @@ -0,0 +1,71 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, +available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) \ No newline at end of file diff --git a/data/ARKitScenes/CONTRIBUTING.md b/data/ARKitScenes/CONTRIBUTING.md new file mode 100644 index 0000000..c5364ed --- /dev/null +++ b/data/ARKitScenes/CONTRIBUTING.md @@ -0,0 +1,11 @@ +# Contribution Guide + +Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. + +While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. + +## Before you get started + +By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). + +We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). \ No newline at end of file diff --git a/data/ARKitScenes/DATA.md b/data/ARKitScenes/DATA.md new file mode 100644 index 0000000..36e7332 --- /dev/null +++ b/data/ARKitScenes/DATA.md @@ -0,0 +1,100 @@ +# Data download +This section will guide you on how to download the datasets, and explain about each file format exist in the datasets. + +## datasets +ARKitScenes includes 3 datasets, +1. `3dod` - The dataset used to train 3d object detection. The dataset includes 3 assets: low resolution RGB image, low resolution depth image and the labels (The total Size is 623.4 GB for 5047 threedod scans) +2. `upsampling` - The dataset used to train depth upsampling. The dataset includes 3 assets: high resolution RGB image, low resolution depth image and high resolution depth image +3. `raw` - This dataset includes all data available in ARKitScenes, the 3dod and depth upsampling datasets are a subset of it, +the dataset includes much more assets that are not part of 3DOD or depth upsampling. + + +## Downloading data +Each dataset has a CSV file that includes all the `visit_id`, `video_id` and `fold` available in the dataset. + +3DOD CSV path: +``` +ARKitScenes/threedod/3dod_train_val_splits.csv +``` +Upsampling CSV path: +``` +ARKitScenes/depth_upsampling/upsampling_train_val_splits.csv +``` +Raw CSV path: +``` +ARKitScenes/raw/raw_train_val_splits.csv +``` + +To download each one of the datasets, we added a python script - `download_data.py`. + +To download a specific video_id or series of video_ids, `download_data.py` expect the first argument to be the dataset name (i.e. 3dod/upsampling/raw) +the second argument the fold (i.e. Training/Validation) and video_id or series of video_ids. + +```shell script +python3 download_data.py [3dod/upsampling/raw] --split [Training/Validation] --video_id video_id1 video_id2 \ +--download_dir YOUR_DATA_FOLDER +``` +for example +```shell script +python3 download_data.py raw --split Training --video_id 47333462 \ +--download_dir /tmp/ARKitScenes/ +``` +or +```shell script +python3 download_data.py raw --split Training --video_id 47333462 \ +--download_dir /tmp/ARKitScenes/ --download_laser_scanner_point_cloud +``` +to download the laser scanner point-clouds (available only for the raw dataset) + +To download with CSV, `download_data.py` expect the first argument to be a dataset name (i.e. 3dod/upsampling/raw), +and no need for the fold, because the fold information exist in the CSV file. +```shell script +python3 download_data.py [3dod/upsampling/raw] --video_id_csv CSV_PATH \ +--download_dir YOUR_DATA_FOLDER +``` +for example +```shell script +python3 download_data.py 3dod --video_id_csv threedod/3dod_train_val_splits.csv \ +--download_dir /tmp/raw_ARKitScenes/ +``` + +Please note that for raw data, you will need to specify the type(s) of data you would like to download. +The choices are +``` +mov annotation mesh confidence highres_depth lowres_depth lowres_wide.traj lowres_wide lowres_wide_intrinsics ultrawide +ultrawide_intrinsics vga_wide vga_wide_intrinsics +``` + +for example +```shell script +python3 download_data.py raw --video_id_csv raw/raw_train_val_splits.csv --download_dir /tmp/ar_raw_all/ \ +--raw_dataset_assets mov annotation mesh confidence highres_depth lowres_depth lowres_wide.traj \ +lowres_wide lowres_wide_intrinsics ultrawide ultrawide_intrinsics vga_wide vga_wide_intrinsics +``` + +The data folder (i.e. `YOUR_DATA_DIR`) will includes two directories, `Training` and `Validation` which includes all the assets +belonging to training and validation bin respectively. + +## Dataset files formats +The dataset includes the following formats +1. `.png` - store RGB images, depth images and confidence images + - `RGB images` - regular `uint8`, 3 channel image + - `depth image` - `uint16` png format in millimeters + - `confidence` - `uint8` png format `0`-low confidence `2`-high confidence +2. `.pincam` - store the intrinsic matrix for each RGB image + - is a single-line text file, space-delimited, with the following fields: + `width` `height` `focal_length_x` `focal_length_y` `principal_point_x` `principal_point_y` +3. `.json` - store the object annotation +4. `.traj` - is a space-delimited file where each line represents a camera position at a particular timestamp + - Column 1: timestamp + - Columns 2-4: rotation (axis-angle representation in radians) + - Columns 5-7: translation (in meters) +5. `.ply` - store the mesh generated by ARKit or the point-clouds generated by the Faro laser scanner +6. `.mov` - video captured with ARKit (raw dataset only) +7. `_pose.txt` - Transformation matrix to align/register multiple FARO scans - Lines 0-2 contain the rotation matrix and line 3 the translation vector + +## Dataset structure +To deep dive into the structure of each of the datasets please go to the documentation of each one of the datasets +### [RAW](raw/README.md) +### [3DOD](threedod/README.md) +### [Depth upsampling](depth_upsampling/README.md) diff --git a/data/ARKitScenes/LICENSE b/data/ARKitScenes/LICENSE new file mode 100644 index 0000000..be59641 --- /dev/null +++ b/data/ARKitScenes/LICENSE @@ -0,0 +1,20 @@ +Copyright (C) 2021 Apple Inc. All Rights Reserved. + +IMPORTANT: This Apple software is supplied to you by Apple Inc. ("Apple") in consideration of your agreement to the following terms, and your use, installation, modification or redistribution of this Apple software constitutes acceptance of these terms. If you do not agree with these terms, please do not use, install, modify or redistribute this Apple software. + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +In consideration of your agreement to abide by the following terms, and subject to these terms, Apple grants you a personal, non-commercial, non-exclusive license, under Apple's copyrights in this original Apple software (the "Apple Software"), to use, reproduce, modify and redistribute the Apple Software, with or without modifications, in source and/or binary forms for non-commercial purposes only; provided that if you redistribute the Apple Software in its entirety and without modifications, you must retain this notice and the following text and disclaimers in all such redistributions of the Apple Software. Neither the name, trademarks, service marks or logos of Apple Inc. may be used to endorse or promote products derived from the Apple Software without specific prior written permission from Apple. Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent rights that may be infringed by your derivative works or by other works in which the Apple Software may be incorporated. + +Commercial License terms: If in every month prior to August, 2024, the monthly active users of the products or services made available by or for Licensee, or Licensee’s affiliates, was less than 700 million monthly active users, then in consideration of your agreement to abide by the following terms, and subject to these terms, Apple grants you a personal, non-exclusive license, under Apple's copyrights in this original Apple software (the "Apple Software"), to use, reproduce, modify and redistribute the Apple Software, with or without modifications, in source and/or binary forms; provided that if you redistribute the Apple Software in its entirety and without modifications, you must retain this notice and the following text and disclaimers in all such redistributions of the Apple Software. Neither the name, trademarks, service marks or logos of Apple Inc. may be used to endorse or promote products derived from the Apple Software without specific prior written permission from Apple. Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent rights that may be infringed by your derivative works or by other works in which the Apple Software may be incorporated. If in any month prior to August, 2024, the monthly active users of the products or services made available by or for Licensee, or Licensee’s affiliates, is greater than or equal to 700 million monthly active users in any preceding calendar month, you must request a license from Apple, which Apple may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Apple otherwise expressly grants you such rights + +The Apple Software is provided by Apple on an "AS IS" basis. APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. + +IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +------------------------------------------------------------------------------- +[SOFTWARE DISTRIBUTED IN THIS REPOSITORY: +This software includes a number of subcomponents with separate copyright +notices and license terms - please see the file ACKNOWLEDGEMENTS.txt.] + + diff --git a/data/ARKitScenes/README.md b/data/ARKitScenes/README.md new file mode 100644 index 0000000..aa2b994 --- /dev/null +++ b/data/ARKitScenes/README.md @@ -0,0 +1,77 @@ +# ARKitScenes + +This repo accompanies the research paper, [ARKitScenes - A Diverse Real-World Dataset for 3D Indoor Scene Understanding +Using Mobile RGB-D Data](https://openreview.net/forum?id=tjZjv_qh_CE) and contains the data, scripts to visualize +and process assets, and training code described in our paper. + +![image](https://user-images.githubusercontent.com/7753049/144107932-39b010fc-6111-4b13-9c68-57dd903d78c5.png) + +![image](https://user-images.githubusercontent.com/7753049/144108052-6a1d3a67-3948-4ded-bd08-6f1572fdf97a.png) + +## Paper +[ARKitScenes - A Diverse Real-World Dataset for 3D Indoor Scene Understanding +Using Mobile RGB-D Data](https://openreview.net/forum?id=tjZjv_qh_CE) + +upon using these data or source code, please cite +```buildoutcfg +@inproceedings{ +dehghan2021arkitscenes, +title={{ARK}itScenes - A Diverse Real-World Dataset for 3D Indoor Scene Understanding Using Mobile {RGB}-D Data}, +author={Gilad Baruch and Zhuoyuan Chen and Afshin Dehghan and Tal Dimry and Yuri Feigin and Peter Fu and Thomas Gebauer and Brandon Joffe and Daniel Kurz and Arik Schwartz and Elad Shulman}, +booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 1)}, +year={2021}, +url={https://openreview.net/forum?id=tjZjv_qh_CE} +} +``` + +## Overview +ARKitScenes is not only the first RGB-D dataset that is captured with now widely available depth sensor, but also is the +largest indoor scene understanding data ever collected. In addition to the raw and processed data, ARKitScenes includes +high resolution depth maps captured using a stationary laser scanner, as well as manually labeled 3D oriented bounding +boxes for a large taxonomy of furniture. We further provide helper scripts for two downstream tasks: +3D object detection and RGB-D guided upsampling. We hope that our dataset can help push the boundaries of +existing state-of-the-art methods and introduce new challenges that better represent real world scenarios. + +## Key features +• ARKitScenes is the first RGB-D dataset captured with the widely available +Apple LiDAR scanner. Along with the raw data we provide the camera pose and surface +reconstruction for each scene. + +• ARKitScenes is the largest indoor 3D dataset consisting of 5,047 captures of 1,661 unique +scenes. + +• We provide high quality ground truth of (a) registered RGB-D frames and (b) oriented +bounding boxes of room defining objects. + +Below is an overview of RGB-D datasets and their ground truth assets compared with ARKitScenes. +HR and LR represent High Resolution and Low Resolution respectively, and are available for a subset of 2,257 captures of 841 unique +scenes. + +![image](https://user-images.githubusercontent.com/7753049/144108117-b789a5be-cc08-44f0-a76c-f1549c59825e.png) + + +## Data collection + +In the figure below, we provide (a) illustration of iPad Pro scanning set up. (b) mesh overlay to assist data collection with iPad Pro. (c) example of one of the scan patterns captured with the iPad pro, the red markers show the chosen locations of the stationary laser scanner in that room. + +![image](https://user-images.githubusercontent.com/7753049/144108161-0ae7ba6a-305f-4a22-93b1-0b2d1e78154e.png) + +## Data download + +To download the data please follow the [data](DATA.md) documentation + +## Tasks + +Here we provide the two tasks mentioned in our paper, namely, 3D Object Detection (3DOD) and depth upsampling. + +### [3DOD](threedod/README.md) + +### [Depth upsampling](depth_upsampling/README.md) + +## License + +Please refer to the LICENSE file for detailed information on using the dataset. + +For any additional inquiries about the license, please contact ARKitScenes-license@group.apple.com. + +If you have other questions, feel free to open an issue in the repository or contact ARKitScenes@group.apple.com. diff --git a/data/ARKitScenes/__init__.py b/data/ARKitScenes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/ARKitScenes/depth_upsampling/README.md b/data/ARKitScenes/depth_upsampling/README.md new file mode 100644 index 0000000..c4dc40b --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/README.md @@ -0,0 +1,77 @@ +# Depth upsampling + +## Data download +To download the data please follow the [data](../DATA.md) documentation + +## Data organization and format of input data +The dataset includes 4 types of assets and metadata: +1. `color` - the RGB images (1920x1440) +2. `highres_depth` - the ground-truth depth image projected from the mesh generated by Faro’s laser scanners (1920x1440) +3. `lowres_depth` - the depth image acquired by AppleDepth Lidar (256x192) +4. `confidence` - the confidence of the AppleDepth depth image (256x192) +5. `metadata.csv` - meta data per video (i.e. sky direction - (up/down/left/right)) +6. `val_attributes.csv` - attributes per sample (i.e. `transparent_or_reflective` - if True, the image includes a transparent or reflective objects). Manually annotated and only relevant for the Validation bin. + +[Data](../DATA.md) documentation describe the format of each one of the asset. + + +``` +ARKitScenes/depth_upsampling/ +├── Training # training bin assets folder +│ ├── 41069021 # video_id assets folder +│ │ ├── color # color assets folder +│ │ │ ├── 41069021_305.244.png # color frames +│ │ │ ├── 41069021_307.343.png +│ │ │ ├── 41069021_309.742.png +│ │ │ └── ... +│ │ ├── highres_depth # highres_depth folder +│ │ │ ├── 41069021_305.244.png # highres_depth frames +│ │ │ ├── 41069021_307.343.png +│ │ │ ├── 41069021_309.742.png +│ │ │ └── ... +│ │ ├── lowres_depth # lowres_depth folder +│ │ │ ├── 41069021_305.244.png # lowres_depth frames +│ │ │ ├── 41069021_307.343.png +│ │ │ ├── 41069021_309.742.png +│ │ │ └── ... +│ │ └── confidence # confidence folder +│ │ ├── 41069021_305.244.png # confidence frames +│ │ ├── 41069021_307.343.png +│ │ ├── 41069021_309.742.png +│ │ └── ... +│ ├── +│ └── ... +└── Validation # validation bin assets folder + └── ... +``` + +## Creating a python environment +The packages required for training depth upsampling are listed in the file `requirements.txt`, +to install them run + +```shell script +cd depth_upsampling +pip install -r requirements.txt +``` + +## Visualizing depth upsampling assets +To view upsampling assets you can use the following script: +(note that first you need to [download](#Data download) the dataset) +```shell script +python3 depth_upsampling/sample_vis.py YOUR_DATA_DIR/ARKitScenes --split [train/val] --sample_id SAMPLE_ID +``` +for example to visualize a sample from validation bin you can run: +```shell script +python3 depth_upsampling/sample_vis.py YOUR_DATA_DIR/ARKitScenes --split val --sample_id 41069021_305.244.png +``` +## training depth upsampling + +You can train the upsampling networks by running +```shell script +python train.py --network [MSG/MSPF] --upsample_factor [2/4/8] +``` +The training script will print to the screen the metrics once every 5k iterations. +To view the results on tensorboard +you can add a tensorboard port parameter `--tbp some_port_number` to the `train.py` input parameters. +This will automatically open a tensorboard process on a subprocess. + \ No newline at end of file diff --git a/data/ARKitScenes/depth_upsampling/__init__.py b/data/ARKitScenes/depth_upsampling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/ARKitScenes/depth_upsampling/data_utils.py b/data/ARKitScenes/depth_upsampling/data_utils.py new file mode 100644 index 0000000..82b5194 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/data_utils.py @@ -0,0 +1,32 @@ +import numpy as np +import torch + + +def expand_channel_dim(img): + """ + expand image dimension to add a channel dimension + """ + return np.expand_dims(img, 0) + + +def image_hwc_to_chw(img): + """ + transpose the image from height, width, channel -> channel, height, width + (pytorch format) + """ + return img.transpose((2, 0, 1)) + + +def image_chw_to_hwc(img): + """ + revert image_hwc_to_chw function + """ + return img.transpose((1, 2, 0)) + + +def batch_to_cuda(batch): + if torch.cuda.is_available(): + for k in batch: + if k != 'identifier': + batch[k] = batch[k].cuda(non_blocking=True) + return batch diff --git a/data/ARKitScenes/depth_upsampling/dataset.py b/data/ARKitScenes/depth_upsampling/dataset.py new file mode 100644 index 0000000..1bbb9f6 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/dataset.py @@ -0,0 +1,130 @@ +import os +from glob import glob +from typing import Callable, Optional + +import cv2 +import numpy as np +import pandas as pd +from torch.utils.data import Dataset + +from . import dataset_keys +from .data_utils import image_hwc_to_chw, expand_channel_dim + +META_DATA_CSV_FILE = 'metadata.csv' +WIDE = 'wide' +LOW_RES = (384, 512) +HIGH_RES = (1440, 1920) +MILLIMETER_TO_METER = 1000 + + +class ARKitScenesDataset(Dataset): + """`ARKitScenes Dataset class. + + Args: + root (string): Root directory of dataset where directory + exists or will be saved to if download is set to True. + transform (callable, optional): A function that takes in a sample (dict) + and returns a transformed version. + download (bool, optional): If true, downloads the dataset from the internet and + puts on the root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + def __init__( + self, + root: str, + split: str = 'train', + transform: Optional[Callable] = None, + upsample_factor: Optional[int] = None, + ) -> None: + + super(ARKitScenesDataset, self).__init__() + self.root = os.path.expanduser(root) + self.split = split + self.transform = transform + self.upsample_factor = upsample_factor + + if self.upsample_factor is not None and self.upsample_factor not in (2, 4, 8): + raise ValueError(f'rgb_factor must to be one of (2,4,8) but got {self.upsample_factor}') + if split == 'train': + self.split_folder = 'Training' + elif split == 'val': + self.split_folder = 'Validation' + else: + raise Exception(f'split must to be one of (train, val), got ={split}') + self.dataset_folder = os.path.join(self.root, self.split_folder) + if self.upsample_factor is None: + self.low_res = LOW_RES + self.high_res = HIGH_RES + else: + if self.upsample_factor in (2, 4): + self.low_res = LOW_RES + self.high_res = [i * self.upsample_factor for i in LOW_RES] + elif self.upsample_factor == 8: + self.high_res = HIGH_RES + self.low_res = [int(i / self.upsample_factor) for i in HIGH_RES] + else: + raise Exception(f'Can\'t load dataset with upsample_factor = {self.upsample_factor}') + + self.samples = [] # videos_id, sample_id, sky_direction + self.meta_data = pd.read_csv(os.path.join(os.path.dirname(self.dataset_folder), META_DATA_CSV_FILE)) + self.meta_data = self.meta_data[self.meta_data['fold'] == self.split_folder] + for video_id, sky_direction in zip(self.meta_data['video_id'], self.meta_data['sky_direction']): + video_folder = os.path.join(self.dataset_folder, str(video_id)) + color_files = glob(os.path.join(video_folder, WIDE, '*.png')) + self.samples.extend([[str(video_id), str(os.path.basename(file)), sky_direction] + for file in color_files]) + + @staticmethod + def rotate_image(img, direction): + if direction == 'Up': + pass + elif direction == 'Left': + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + elif direction == 'Right': + img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) + elif direction == 'Down': + img = cv2.rotate(img, cv2.ROTATE_180) + else: + raise Exception(f'No such direction (={direction}) rotation') + return img + + @staticmethod + def load_image(path, shape, is_depth, sky_direction): + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.shape[:2] != shape: + img = cv2.resize(img, shape[::-1], interpolation=cv2.INTER_NEAREST if is_depth else cv2.INTER_LINEAR) + img = ARKitScenesDataset.rotate_image(img, sky_direction) + if is_depth: + img = expand_channel_dim(np.asarray(img / MILLIMETER_TO_METER, np.float32)) + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = image_hwc_to_chw(np.asarray(img, np.float32)) + return img + + def __getitem__(self, index: int): + """ + Args: + index (int): Index + + Returns: + tuple: (identifier, color, highres_depth, lowres_depth). + """ + video_id, sample_id, direction = self.samples[index] + sample = dict() + sample[dataset_keys.IDENTIFIER] = str(sample_id) + + rgb_file = os.path.join(self.dataset_folder, video_id, WIDE, sample_id) + depth_file = os.path.join(self.dataset_folder, video_id, 'highres_depth', sample_id) + apple_depth_file = os.path.join(self.dataset_folder, video_id, 'lowres_depth', sample_id) + + sample[dataset_keys.COLOR_IMG] = self.load_image(rgb_file, self.high_res, False, direction) + sample[dataset_keys.HIGH_RES_DEPTH_IMG] = self.load_image(depth_file, self.high_res, True, direction) + sample[dataset_keys.LOW_RES_DEPTH_IMG] = self.load_image(apple_depth_file, self.low_res, True, direction) + + if self.transform is not None: + sample[dataset_keys.COLOR_IMG] = self.transform(sample[dataset_keys.COLOR_IMG]) + return sample + + def __len__(self) -> int: + return len(self.samples) diff --git a/data/ARKitScenes/depth_upsampling/dataset_keys.py b/data/ARKitScenes/depth_upsampling/dataset_keys.py new file mode 100644 index 0000000..16f87ec --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/dataset_keys.py @@ -0,0 +1,6 @@ +IDENTIFIER = 'identifier' +COLOR_IMG = 'color_img' +HIGH_RES_DEPTH_IMG = 'high_res_depth_img' +LOW_RES_DEPTH_IMG = 'low_res_depth_img' +PREDICTION_DEPTH_IMG = 'prediction_img' +VALID_MASK_IMG = 'valid_mask_img' diff --git a/data/ARKitScenes/depth_upsampling/image_utils.py b/data/ARKitScenes/depth_upsampling/image_utils.py new file mode 100644 index 0000000..a4e3661 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/image_utils.py @@ -0,0 +1,40 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np + + +def create_montage_image(image_list, image_shape=(640, 480), grid_shape=(3, 1)) -> np.ndarray: + height, width = image_shape[1], image_shape[0] + montage = np.zeros((image_shape[1] * grid_shape[1], image_shape[0] * grid_shape[0], 3), dtype="uint8") + + x_shift = 0 + y_shift = 0 + + for n in range(len(image_list)): + image = cv2.resize(image_list[n], (width, height), interpolation=cv2.INTER_NEAREST) + montage[y_shift * height: (y_shift + 1) * height, x_shift * width: (x_shift + 1) * width] = image + x_shift += 1 + + if x_shift % (grid_shape[0]) == 0 and x_shift > 0: + y_shift += 1 + x_shift = 0 + + return montage + + +def colorize(image, vmin=None, vmax=None, cmap='turbo'): + + vmin = image.min() if vmin is None else vmin + vmax = image.max() if vmax is None else vmax + + if vmin != vmax: + image = (image - vmin) / (vmax - vmin) + else: + image = image * 0. + + cmapper = plt.cm.get_cmap(cmap) + image = cmapper(image, bytes=True) + + img = image[:, :, :3] + + return img diff --git a/data/ARKitScenes/depth_upsampling/logs/__init__.py b/data/ARKitScenes/depth_upsampling/logs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/ARKitScenes/depth_upsampling/logs/eval.py b/data/ARKitScenes/depth_upsampling/logs/eval.py new file mode 100644 index 0000000..39842f3 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/logs/eval.py @@ -0,0 +1,71 @@ +from collections import defaultdict + +import numpy as np +import torch + +import dataset_keys +import image_utils +from data_utils import image_hwc_to_chw, image_chw_to_hwc, batch_to_cuda + +MAX_TENSORBOARD_FRAMES = 10 +VMIN = 0 +VMAX = 7 + + +def compute_errors(gt, pred, valid_mask): + l1 = np.abs((gt - pred) * valid_mask) + l2 = l1 ** 2 + + image_dim = (1, 2) + denominator = np.count_nonzero(valid_mask, axis=image_dim) + l1_mean = np.sum(l1, image_dim) / denominator + rmse = np.sqrt(np.sum(l2, image_dim) / denominator) + + return dict(L1=l1_mean, RMSE=rmse) + + +def eval_log(step, model, dataloader, tensorboard_writer): + model.eval() + with torch.no_grad(): + metrics = defaultdict(list) + images_added_to_tensorboard = 0 + total_samples = 0 + for i, input_batch in enumerate(dataloader): + input_batch = batch_to_cuda(input_batch) + output_batch = model(input_batch) + rgb = input_batch[dataset_keys.COLOR_IMG].cpu().numpy() + gt_depth = input_batch[dataset_keys.HIGH_RES_DEPTH_IMG].cpu().numpy().squeeze(1) + depth_lowres = input_batch[dataset_keys.LOW_RES_DEPTH_IMG].cpu().numpy().squeeze(1) + valid_mask = input_batch[dataset_keys.VALID_MASK_IMG].cpu().numpy().squeeze(1) + pred_depth = output_batch[dataset_keys.PREDICTION_DEPTH_IMG].cpu().numpy().squeeze(1) + + batch_size = rgb.shape[0] + total_samples += batch_size + batch_metrics = compute_errors(gt_depth, pred_depth, valid_mask) + for key in batch_metrics: + metrics[key].append(batch_metrics[key]) + + j = 0 + while j < batch_size and images_added_to_tensorboard <= MAX_TENSORBOARD_FRAMES: + identifier = input_batch[dataset_keys.IDENTIFIER][j] + image_list = [image_chw_to_hwc(rgb[j]), + image_utils.colorize(gt_depth[j], VMIN, VMAX), + image_utils.colorize(depth_lowres[j], VMIN, VMAX), + image_utils.colorize(pred_depth[j], VMIN, VMAX)] + + h, w = pred_depth[j].shape[:2] + montage = image_utils.create_montage_image(image_list, (w, h), grid_shape=(4, 1)) + tensorboard_writer.add_image(f'{identifier}', image_hwc_to_chw(montage / 255), + step) + j += 1 + images_added_to_tensorboard += 1 + + print(f'validation metrics') + print(("{:>7}, " * len(metrics)).format(*metrics.keys())) + for key in metrics: + metrics[key] = np.concatenate(metrics[key]) + metric = np.sum(metrics[key]) / total_samples + print('{:7.3f}, '.format(metric), end='') + tensorboard_writer.add_scalar(key, metric, step) + print() + model.train() diff --git a/data/ARKitScenes/depth_upsampling/logs/train.py b/data/ARKitScenes/depth_upsampling/logs/train.py new file mode 100644 index 0000000..06a470b --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/logs/train.py @@ -0,0 +1,35 @@ +import dataset_keys +import image_utils +from data_utils import image_hwc_to_chw, image_chw_to_hwc + +MAX_TENSORBOARD_IMAGES = 6 + + +def train_log(step, input_batch, output_batch, tensorboard_writer, **kwargs): + if step % 100 == 0: + loss = kwargs['loss'].detach().cpu().numpy() + current_lr = kwargs['current_lr'] + print('step={}, loss: {:.12f}'.format(step, loss)) + if tensorboard_writer is not None: + tensorboard_writer.add_scalar('Training/loss', loss, step) + tensorboard_writer.add_scalar('Training/learning_rate', current_lr, step) + + if step % 2000 == 0 and tensorboard_writer is not None: + rgb = input_batch[dataset_keys.COLOR_IMG].detach().cpu().numpy() + gt_depth = input_batch[dataset_keys.HIGH_RES_DEPTH_IMG].detach().cpu().numpy().squeeze(1) + depth_lowres = input_batch[dataset_keys.LOW_RES_DEPTH_IMG].detach().cpu().numpy().squeeze(1) + valid_mask = input_batch[dataset_keys.VALID_MASK_IMG].detach().cpu().numpy().squeeze(1) + pred_depth = output_batch[dataset_keys.PREDICTION_DEPTH_IMG].detach().cpu().numpy().squeeze(1) + for i in range(min(rgb.shape[0], MAX_TENSORBOARD_IMAGES)): + vmin = 0 + vmax = gt_depth.max() + + image_list = [image_chw_to_hwc(rgb[i]), + image_utils.colorize(gt_depth[i], vmin, vmax), + image_utils.colorize(depth_lowres[i], vmin, vmax), + image_utils.colorize(pred_depth[i], vmin, vmax), + image_utils.colorize(valid_mask[i], 0, 1)] + + h, w = pred_depth[i].shape[:2] + montage = image_utils.create_montage_image(image_list, (w, h), grid_shape=(5, 1)) + tensorboard_writer.add_image(f'Training/{i}', image_hwc_to_chw(montage / 255), step) diff --git a/data/ARKitScenes/depth_upsampling/losses/__init__.py b/data/ARKitScenes/depth_upsampling/losses/__init__.py new file mode 100644 index 0000000..616cdcd --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/losses/__init__.py @@ -0,0 +1,21 @@ +from .gradient_loss import gradient_loss +from .l1_loss import l1_loss +from .rmse import rmse_loss + + +def mspf_loss(output_batch, input_batch): + return l1_loss(output_batch, input_batch) + 2 * gradient_loss(output_batch, input_batch) + + +def msg_loss(output_batch, input_batch): + return rmse_loss(output_batch, input_batch) + + +def get_loss(network): + if network == 'MSG': + loss = msg_loss + elif network == 'MSPF': + loss = mspf_loss + else: + raise ValueError(f'No such network ({network})') + return loss diff --git a/data/ARKitScenes/depth_upsampling/losses/gradient_loss.py b/data/ARKitScenes/depth_upsampling/losses/gradient_loss.py new file mode 100644 index 0000000..2a374d9 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/losses/gradient_loss.py @@ -0,0 +1,113 @@ +import torch + +import dataset_keys + + +def div_by_mask_sum(loss: torch.Tensor, mask_sum: torch.Tensor): + return loss / torch.max(mask_sum, torch.ones_like(mask_sum)) + + +class SafeTorchLog(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + """ + In the forward pass we receive a Tensor containing the input and return + a Tensor containing the output. ctx is a context object that can be used + to stash information for backward computation. You can cache arbitrary + objects for use in the backward pass using the ctx.save_for_backward method. + """ + + input_abs = torch.abs(input) + 1e-9 + ctx.save_for_backward(input_abs) + + return torch.log(input_abs) + + @staticmethod + def backward(ctx, grad_output): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. + """ + (input_abs,) = ctx.saved_tensors + grad_input = grad_output.clone() + + return grad_input * (1.0 / input_abs) / 2.302585093 # ln(10) + + +safe_torch_log = SafeTorchLog.apply + + +def create_gradient_log_loss(log_prediction_d, mask, log_gt): + + # compute log difference + log_d_diff = log_prediction_d - log_gt + log_d_diff = torch.mul(log_d_diff, mask) + + # compute vertical gradient + v_gradient = torch.abs(log_d_diff[:, :, 2:, :] - log_d_diff[:, :, :-2, :]) + v_mask = torch.mul(mask[:, :, 2:, :], mask[:, :, :-2, :]) + v_gradient = torch.mul(v_gradient, v_mask) + + # compute horizontal gradient + h_gradient = torch.abs(log_d_diff[:, :, :, 2:] - log_d_diff[:, :, :, :-2]) + h_mask = torch.mul(mask[:, :, :, 2:], mask[:, :, :, :-2]) + h_gradient = torch.mul(h_gradient, h_mask) + + # sum up gradients + grad_loss = torch.sum(h_gradient, dim=[1, 2, 3]) + torch.sum(v_gradient, dim=[1, 2, 3]) + num_valid_pixels = torch.sum(mask, dim=[1, 2, 3]) + grad_loss = div_by_mask_sum(grad_loss, num_valid_pixels) + + return grad_loss + + +def create_gradient_log_loss_4_scales(log_prediction, log_ground_truth, mask): + log_prediction_d = log_prediction + log_gt = log_ground_truth + mask = mask + + log_prediction_d_scale_1 = log_prediction_d[:, :, ::2, ::2] + log_prediction_d_scale_2 = log_prediction_d_scale_1[:, :, ::2, ::2] + log_prediction_d_scale_3 = log_prediction_d_scale_2[:, :, ::2, ::2] + + mask_scale_1 = mask[:, :, ::2, ::2] + mask_scale_2 = mask_scale_1[:, :, ::2, ::2] + mask_scale_3 = mask_scale_2[:, :, ::2, ::2] + + log_gt_scale_1 = log_gt[:, :, ::2, ::2] + log_gt_scale_2 = log_gt_scale_1[:, :, ::2, ::2] + log_gt_scale_3 = log_gt_scale_2[:, :, ::2, ::2] + + gradient_loss_scale_0 = create_gradient_log_loss(log_prediction_d, mask, log_gt) + + gradient_loss_scale_1 = create_gradient_log_loss( + log_prediction_d_scale_1, mask_scale_1, log_gt_scale_1 + ) + + gradient_loss_scale_2 = create_gradient_log_loss( + log_prediction_d_scale_2, mask_scale_2, log_gt_scale_2 + ) + + gradient_loss_scale_3 = create_gradient_log_loss( + log_prediction_d_scale_3, mask_scale_3, log_gt_scale_3 + ) + + gradient_loss_4_scales = ( + gradient_loss_scale_0 + gradient_loss_scale_1 + gradient_loss_scale_2 + gradient_loss_scale_3 + ) + + return gradient_loss_4_scales + + +def gradient_loss(outputs, inputs): + valid_mask = inputs[dataset_keys.VALID_MASK_IMG] + gt_depth = inputs[dataset_keys.HIGH_RES_DEPTH_IMG] + prediction = outputs[dataset_keys.PREDICTION_DEPTH_IMG] + + log_prediction = safe_torch_log(prediction) + log_gt = safe_torch_log(gt_depth) + loss = create_gradient_log_loss_4_scales(log_prediction, log_gt, valid_mask) + loss = torch.mean(loss) + return loss diff --git a/data/ARKitScenes/depth_upsampling/losses/l1_loss.py b/data/ARKitScenes/depth_upsampling/losses/l1_loss.py new file mode 100644 index 0000000..61e9f92 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/losses/l1_loss.py @@ -0,0 +1,17 @@ +import torch + +import dataset_keys + +eps = 1e-6 + + +def l1_loss(outputs, inputs): + valid_mask = inputs[dataset_keys.VALID_MASK_IMG] + gt_depth = inputs[dataset_keys.HIGH_RES_DEPTH_IMG] + prediction = outputs[dataset_keys.PREDICTION_DEPTH_IMG] + + error_image = torch.abs(prediction - gt_depth) * valid_mask + sum_loss = torch.sum(error_image, dim=[1, 2, 3]) + num_valid_pixels = torch.sum(valid_mask, dim=[1, 2, 3]) + loss = sum_loss / torch.max(num_valid_pixels, torch.ones_like(num_valid_pixels) * eps) + return torch.mean(loss) diff --git a/data/ARKitScenes/depth_upsampling/losses/rmse.py b/data/ARKitScenes/depth_upsampling/losses/rmse.py new file mode 100644 index 0000000..7bba109 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/losses/rmse.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + +import dataset_keys + + +def rmse_loss(outputs, inputs): + valid_mask = inputs[dataset_keys.VALID_MASK_IMG] + gt_depth = inputs[dataset_keys.HIGH_RES_DEPTH_IMG] + prediction = outputs[dataset_keys.PREDICTION_DEPTH_IMG] + loss = F.mse_loss(prediction[valid_mask], gt_depth[valid_mask]) + loss = torch.sqrt(loss) + return loss diff --git a/data/ARKitScenes/depth_upsampling/models/__init__.py b/data/ARKitScenes/depth_upsampling/models/__init__.py new file mode 100644 index 0000000..df9979f --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/__init__.py @@ -0,0 +1,31 @@ +import numpy as np +import torch +from torch import nn + +from .msg.msg import MSGNet +from .mspf.mspf import MSPF + + +def weights_init_xavier(m): + if isinstance(m, nn.Conv2d): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + +def get_network(network, upsampling_factor): + # Create model + if network == 'MSG': + model = MSGNet(upsampling_factor) + elif network == 'MSPF': + model = MSPF(upsampling_factor) + model.decoder.apply(weights_init_xavier) + else: + raise ValueError(f'No such network ({network})') + + num_params = sum([np.prod(p.size()) for p in model.parameters()]) + print("Total number of parameters: {}".format(num_params)) + + num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad]) + print("Total number of learning parameters: {}".format(num_params_update)) + return model diff --git a/data/ARKitScenes/depth_upsampling/models/msg/__init__.py b/data/ARKitScenes/depth_upsampling/models/msg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/ARKitScenes/depth_upsampling/models/msg/blocks.py b/data/ARKitScenes/depth_upsampling/models/msg/blocks.py new file mode 100644 index 0000000..5366acd --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/msg/blocks.py @@ -0,0 +1,26 @@ +import torch.nn as nn + + +class ConvPReLu(nn.Module): + def __init__(self, in_channels, out_channels, kernel=5, stride=1, padding=2): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding) + self.activation = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.activation(x) + return x + + +class DeconvPReLu(nn.Module): + def __init__(self, in_channels, out_channels, kernel=3, stride=2, padding=1): + super().__init__() + self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel, + stride=stride, padding=padding, output_padding=stride - 1) + self.activation = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.activation(x) + return x diff --git a/data/ARKitScenes/depth_upsampling/models/msg/msg.py b/data/ARKitScenes/depth_upsampling/models/msg/msg.py new file mode 100644 index 0000000..51bb66c --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/msg/msg.py @@ -0,0 +1,63 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +from .blocks import ConvPReLu, DeconvPReLu +import dataset_keys + + +class MSGNet(nn.Module): + """ + Inspired by: Depth Map Super-Resolution by Deep Multi-Scale Guidance + http://personal.ie.cuhk.edu.hk/~ccloy/files/eccv_2016_depth.pdf + """ + def __init__(self, upsampling_factor): + super().__init__() + # initialize indexes for layers + self.upsampling_factor = upsampling_factor + m = int(np.log2(upsampling_factor)) + + # RGB-branch + self.rgb_encoder1 = nn.Sequential(ConvPReLu(3, 49, 7, stride=1, padding=3), + ConvPReLu(49, 32)) + self.rgb_encoder_blocks = nn.ModuleList() + for i in range(m-1): + self.rgb_encoder_blocks.append(nn.Sequential(ConvPReLu(32, 32), + nn.MaxPool2d(3, 2, padding=1))) + + # D-branch + self.depth_decoder1 = nn.Sequential(ConvPReLu(1, 64, 5, stride=1, padding=2), + DeconvPReLu(64, 32, 5, stride=2, padding=2)) + self.depth_decoder_blocks = nn.ModuleList() + for i in range(m-1): + self.depth_decoder_blocks.append(nn.Sequential(ConvPReLu(64, 32, 5, stride=1, padding=2), + ConvPReLu(32, 32, 5, stride=1, padding=2), + DeconvPReLu(32, 32, 5, stride=2, padding=2))) + + self.depth_decoder_n = nn.Sequential(ConvPReLu(64, 32, 5, stride=1, padding=2), + ConvPReLu(32, 32, 5, stride=1, padding=2), + ConvPReLu(32, 32, 5, stride=1, padding=2), + ConvPReLu(32, 1, 5, stride=1, padding=2)) + + def forward(self, batch): + rgb_img = batch[dataset_keys.COLOR_IMG] / 255 + low_res_depth = batch[dataset_keys.LOW_RES_DEPTH_IMG] + min_d = low_res_depth.amin((1, 2, 3), keepdim=True) + max_d = low_res_depth.amax((1, 2, 3), keepdim=True) + low_res_depth_norm = (low_res_depth - min_d) / ((max_d - min_d) + 1e-8) + low_res_upsampled = F.interpolate(low_res_depth_norm, rgb_img.shape[2:], mode='bicubic') + + rgb_features = [self.rgb_encoder1(rgb_img), ] + for block in self.rgb_encoder_blocks: + rgb_features.append(block(rgb_features[-1])) + + rec = self.depth_decoder1(low_res_depth_norm) + for i, block in enumerate(self.depth_decoder_blocks): + rec = torch.cat((rec, rgb_features[-(i + 1)]), 1) + rec = block(rec) + rec = torch.cat((rec, rgb_features[0]), 1) + rec = self.depth_decoder_n(rec) + + output = (low_res_upsampled + rec) * (max_d - min_d) + min_d + return {dataset_keys.PREDICTION_DEPTH_IMG: output} diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/MultiScaleDepthSR.py b/data/ARKitScenes/depth_upsampling/models/mspf/MultiScaleDepthSR.py new file mode 100644 index 0000000..ac565d4 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/mspf/MultiScaleDepthSR.py @@ -0,0 +1,112 @@ +from typing import List + +import torch +from torch import nn + +from models.mspf.MultiscaleConvDepthEncoder import MultiscaleConvDepthEncoder +from models.mspf.blocks.multi_scale_depth import Upsample2D, Conv2D + +""" +- create conv for matching output channels +- skip convs to reduce channels +""" + + +class MultiscaleDepthDecoder(nn.Module): + """ + Inspired by: Multi-Scale Progressive Fusion Learning for Depth Map Super-Resolution + https://arxiv.org/pdf/2011.11865v1.pdf + """ + def __init__( + self, + input_channels: List[int], + output_channels: List[int], + upsample_factor: int + ): + super().__init__() + + activation = "relu" + batch_norm = None + self.scales = ["x32", "x16", "x8", "x4", "x2", "x1"] + + self.depth_encoder = MultiscaleConvDepthEncoder(upsample_factor) + depth_output_channels = self.depth_encoder.output_channels[::-1] + rgb_output_channels = input_channels + + self.upsample_blocks = {} + for i in range(len(self.scales)-1): + ch_input = rgb_output_channels[i] + \ + depth_output_channels[i] + \ + (output_channels[i-1] if i > 0 else 0) + conv_layers = nn.Sequential( + Conv2D( + ch_input, + output_channels[i], + kernel_size=3, + activation=activation, + padding=1, + batch_norm=batch_norm, + ), + Conv2D( + output_channels[i], + output_channels[i], + kernel_size=3, + activation=activation, + padding=1, + batch_norm=batch_norm, + )) + + upsample = Upsample2D( + output_channels[i], + output_channels[i], + ) + setattr(self, f"conv_layers_{self.scales[i]}", conv_layers) + setattr(self, f"upsample_{self.scales[i]}", upsample) + self.upsample_blocks[self.scales[i]] = (conv_layers, upsample) + + ch_input = 3 + \ + depth_output_channels[-1] + \ + (output_channels[i - 1] if i > 0 else 0) + + conv_layers = nn.Sequential( + Conv2D( + ch_input, + output_channels[i], + kernel_size=3, + activation=activation, + padding=1, + batch_norm=batch_norm, + ), + Conv2D( + output_channels[i], + 1, + kernel_size=3, + activation=None, + padding=1, + batch_norm=batch_norm, + )) + setattr(self, f"conv_layers_x1", conv_layers) + self.upsample_blocks['x1'] = (conv_layers, None) + + def forward( + self, depth: dict, rgb_skip_connections: dict + ): + + depth_skip_connections = self.depth_encoder(depth) + + fusion_features = None + for scale, (conv_layers, upsample) in self.upsample_blocks.items(): + + if fusion_features is None: + fusion_features = torch.cat((rgb_skip_connections[scale], depth_skip_connections[scale]), 1) + else: + fusion_features = torch.cat((rgb_skip_connections[scale], depth_skip_connections[scale], fusion_features), 1) + + fusion_features = conv_layers(fusion_features) + + if upsample is not None: + fusion_features = upsample(fusion_features) + + depth = fusion_features + + return depth diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/MultiscaleConvDepthEncoder.py b/data/ARKitScenes/depth_upsampling/models/mspf/MultiscaleConvDepthEncoder.py new file mode 100644 index 0000000..4c21f77 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/mspf/MultiscaleConvDepthEncoder.py @@ -0,0 +1,79 @@ +import torch.nn.functional as torch_nn_func +from torch import nn + +from models.mspf.blocks.multi_scale_depth import Conv2D + + +class MultiscaleConvDepthEncoder(nn.Module): + """ + Inspired by: Multi-Scale Progressive Fusion Learning for Depth Map Super-Resolution + https://arxiv.org/pdf/2011.11865v1.pdf + """ + def __init__(self, upsample_factor): + super().__init__() + self.scale = int(upsample_factor) + print("self.scale", self.scale) + activation = "relu" + batch_norm = None + self.output_channels = [16, 32, 32, 64, 64, 128] + + self.conv_layers1 = nn.Sequential( + Conv2D( + 1, + self.output_channels[0], + kernel_size=3, + activation=activation, + padding=1, + batch_norm=batch_norm, + ), + Conv2D( + self.output_channels[0], + self.output_channels[0], + kernel_size=3, + activation=activation, + padding=1, + batch_norm=batch_norm, + )) + + self.encoder_conv_blocks = [] + + for i in range(1, 6): + conv_block = nn.Sequential( + Conv2D( + self.output_channels[i-1], + self.output_channels[i], + kernel_size=3, + activation=activation, + padding=1, + batch_norm=batch_norm, + ), + Conv2D( + self.output_channels[i], + self.output_channels[i], + kernel_size=2, + activation=activation, + stride=2, + padding=0, + batch_norm=batch_norm, + )) + + setattr(self, f"conv_block_{i}", conv_block) + self.encoder_conv_blocks.append(conv_block) + + def forward(self, depth): + + depth = torch_nn_func.interpolate(depth, scale_factor=self.scale, mode='bicubic', align_corners=True) + features = self.conv_layers1(depth) + + skip_connections = {} + skip_connections["x1"] = features + + for i in range(5): + features = self.encoder_conv_blocks[i](features) + skip_connections[f"x{(2)**(i+1)}"] = features + + return skip_connections + + + + diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/__init__.py b/data/ARKitScenes/depth_upsampling/models/mspf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/blocks/__init__.py b/data/ARKitScenes/depth_upsampling/models/mspf/blocks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/blocks/dense_net.py b/data/ARKitScenes/depth_upsampling/models/mspf/blocks/dense_net.py new file mode 100644 index 0000000..4789470 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/mspf/blocks/dense_net.py @@ -0,0 +1,117 @@ +from typing import List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class _DenseLayer(nn.Module): + def __init__( + self, + num_input_features: int, + growth_rate: int, + bn_size: int, + drop_rate: float, + memory_efficient: bool = False + ) -> None: + super(_DenseLayer, self).__init__() + self.norm1: nn.BatchNorm2d + self.add_module('norm1', nn.BatchNorm2d(num_input_features)) + self.relu1: nn.ReLU + self.add_module('relu1', nn.ReLU(inplace=True)) + self.conv1: nn.Conv2d + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * + growth_rate, kernel_size=1, stride=1, + bias=False)) + self.norm2: nn.BatchNorm2d + self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) + self.relu2: nn.ReLU + self.add_module('relu2', nn.ReLU(inplace=True)) + self.conv2: nn.Conv2d + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, + kernel_size=3, stride=1, padding=1, + bias=False)) + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bn_function(self, inputs: List[Tensor]) -> Tensor: + concated_features = torch.cat(inputs, 1) + bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 + return bottleneck_output + + def any_requires_grad(self, input: List[Tensor]) -> bool: + for tensor in input: + if tensor.requires_grad: + return True + return False + + @torch.jit._overload_method # noqa: F811 + def forward(self, input: List[Tensor]) -> Tensor: + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, input: Tensor) -> Tensor: + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, input: Tensor) -> Tensor: # noqa: F811 + if isinstance(input, Tensor): + prev_features = [input] + else: + prev_features = input + + if self.memory_efficient and self.any_requires_grad(prev_features): + if torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bn_function(prev_features) + + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, + training=self.training) + return new_features + + +class _DenseBlock(nn.ModuleDict): + _version = 2 + + def __init__( + self, + num_layers: int, + num_input_features: int, + bn_size: int, + growth_rate: int, + drop_rate: float, + memory_efficient: bool = False + ) -> None: + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + drop_rate=drop_rate, + memory_efficient=memory_efficient, + ) + self.add_module('denselayer%d' % (i + 1), layer) + + def forward(self, init_features: Tensor) -> Tensor: + features = [init_features] + for name, layer in self.items(): + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class _Transition(nn.Sequential): + def __init__(self, num_input_features: int, num_output_features: int) -> None: + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm2d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, + kernel_size=1, stride=1, bias=False)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/blocks/multi_scale_depth.py b/data/ARKitScenes/depth_upsampling/models/mspf/blocks/multi_scale_depth.py new file mode 100644 index 0000000..f2ae246 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/mspf/blocks/multi_scale_depth.py @@ -0,0 +1,57 @@ +import torch.nn.functional as torch_nn_func +from torch import nn + + +class Conv2D(nn.Module): + """ + P = ((S-1)*W-S+F)/2, with F = filter size, S = stride + + """ + def __init__(self, in_channels, out_channels, bias=False, kernel_size=1, stride=1, padding=0, dilation=1, activation=None, batch_norm=None): + super(Conv2D, self).__init__() + + self.activation = activation + self.norm = batch_norm + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + bias=bias, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation) + + if self.norm is not None: + self.norm = nn.BatchNorm2d + + if self.activation is not None: + if self.activation == "relu": + self.activation = nn.ReLU() + else: + raise Exception(f"activation {self.activation} not supported") + + def forward(self, x): + out = self.conv(x) + if self.norm is not None: + out = self.norm(out) + if self.activation is not None: + out = self.activation(out) + return out + + +class Upsample2D(nn.Module): + def __init__(self, in_channels, out_channels, ratio=2): + super(Upsample2D, self).__init__() + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + bias=False, kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU() + self.ratio = ratio + + def forward(self, x): + up_x = torch_nn_func.interpolate(x, scale_factor=self.ratio, mode='nearest') + out = self.conv(up_x) + out = self.relu(out) + return out + + diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/densenet.py b/data/ARKitScenes/depth_upsampling/models/mspf/densenet.py new file mode 100644 index 0000000..d780d1b --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/mspf/densenet.py @@ -0,0 +1,103 @@ +from collections import OrderedDict +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from models.mspf.blocks.dense_net import _DenseBlock, _Transition + + +class DenseNet(nn.Module): + """ + Taken from torchvision - slightly adjusted + """ + + def __init__( + self, + growth_rate: int = 32, + block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), + num_init_features: int = 64, + bn_size: int = 4, + drop_rate: float = 0, + num_classes: int = 1000, + memory_efficient: bool = False + ) -> None: + + super(DenseNet, self).__init__() + + # First convolution + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=2, + padding=1, bias=False)), + ('norm0', nn.BatchNorm2d(num_init_features)), + ('relu0', nn.ReLU(inplace=True)), + ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ])) + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + memory_efficient=memory_efficient + ) + self.features.add_module('denseblock%d' % (i + 1), block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + trans = _Transition(num_input_features=num_features, + num_output_features=num_features // 2) + self.features.add_module('transition%d' % (i + 1), trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + + # Linear layer + self.classifier = nn.Linear(num_features, num_classes) + + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + features = self.features(x) + out = F.relu(features, inplace=True) + out = F.adaptive_avg_pool2d(out, (1, 1)) + out = torch.flatten(out, 1) + out = self.classifier(out) + return out + + +class DenseNet121(nn.Module): + def __init__(self): + super().__init__() + model = DenseNet(32, (6, 12, 24, 16), 64) + self.base_model = model.features + self.skip_feature_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5'] + self.skip_out_channels = [64, 64, 128, 256, 1024] + + def forward(self, x): + _, _, h, w = x.shape + features = [x] + skip_feat = {"x1": x} + for k, v in self.base_model._modules.items(): + # ignore classification head + if 'fc' in k or 'avgpool' in k: + continue + feature = v(features[-1]) + features.append(feature) + if any(x in k for x in self.skip_feature_names): + _, _, fh, fw = feature.shape + skip_feat[f"x{int(h/fh)}"] = feature + return skip_feat diff --git a/data/ARKitScenes/depth_upsampling/models/mspf/mspf.py b/data/ARKitScenes/depth_upsampling/models/mspf/mspf.py new file mode 100644 index 0000000..1d65bc1 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/models/mspf/mspf.py @@ -0,0 +1,39 @@ +from torch import nn + +import dataset_keys +from models.mspf.MultiScaleDepthSR import MultiscaleDepthDecoder +from models.mspf.densenet import DenseNet121 + + +class MSPF(nn.Module): + """ + Inspired by: Multi-Scale Progressive Fusion Learning for Depth Map Super-Resolution + https://arxiv.org/pdf/2011.11865v1.pdf + + variables: + decoder_channel_output_scales: available scale factors for reducing channels + in decoder based on encoder input channels. + params: + decoder_channel_scale: used to control decoder size + + """ + + decoder_channel_output_scales = [1, 2, 4, 8, 16] + + def __init__(self, upsample_factor, decoder_channel_scale=2): + super(MSPF, self).__init__() + self.encoder = DenseNet121() + assert decoder_channel_scale in self.decoder_channel_output_scales, \ + f"decoder scale factor not supported {decoder_channel_scale} - supported {self.decoder_channel_output_scales}" + input_channels = self.encoder.skip_out_channels[::-1] + output_channels = [int(ch/decoder_channel_scale) for ch in self.encoder.skip_out_channels[::-1]] + + self.decoder = MultiscaleDepthDecoder(input_channels, output_channels, upsample_factor) + + def forward(self, batch): + rgb = batch[dataset_keys.COLOR_IMG] + rgb = (rgb / 255.0) - 0.5 + depth = batch[dataset_keys.LOW_RES_DEPTH_IMG] + skip_features = self.encoder(rgb) + output = self.decoder(depth, skip_features) + return {dataset_keys.PREDICTION_DEPTH_IMG: output} diff --git a/data/ARKitScenes/depth_upsampling/sample_vis.py b/data/ARKitScenes/depth_upsampling/sample_vis.py new file mode 100644 index 0000000..2d1929a --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/sample_vis.py @@ -0,0 +1,53 @@ +import argparse +import numpy as np +import matplotlib.pyplot as plt +from dataset import ARKitScenesDataset +from data_utils import image_chw_to_hwc +import dataset_keys + + +def sample_vis(dataset_path: str, split: str, sample_id: str, max_depth): + dataset = ARKitScenesDataset(root=dataset_path, split=split) + video_id = sample_id.split('_')[0] + idx = None + for i in range(len(dataset)): + if dataset.samples[i][0] == video_id and dataset.samples[i][1] == sample_id: + idx = i + break + if idx is None: + raise ValueError(f'Can\'t find sample from split={split}, with video_id={video_id} and sample_id={sample_id}') + sample = dataset[idx] + + max_depth = np.min([max_depth, + sample[dataset_keys.HIGH_RES_DEPTH_IMG].max(), + sample[dataset_keys.LOW_RES_DEPTH_IMG].max()]) + fig, axes = plt.subplots(2, 2) + axes[0, 0].set_title('Color img') + axes[0, 0].axis(False) + axes[0, 0].imshow(image_chw_to_hwc(sample[dataset_keys.COLOR_IMG]/255)) + axes[0, 1].set_title('High Res img (0=no depth)') + axes[0, 1].axis(False) + img = axes[0, 1].imshow(sample[dataset_keys.HIGH_RES_DEPTH_IMG][0], vmin=0, vmax=max_depth, cmap=plt.get_cmap('turbo')) + fig.colorbar(img, ax=axes[0, 1]) + axes[1, 1].set_title('Low Res img') + axes[1, 1].axis(False) + img = axes[1, 1].imshow(sample[dataset_keys.LOW_RES_DEPTH_IMG][0], vmin=0, vmax=max_depth, cmap=plt.get_cmap('turbo')) + fig.colorbar(img, ax=axes[1, 1]) + axes[1, 0].set_title('Color and low res overlay') + axes[1, 0].axis(False) + axes[1, 0].imshow(image_chw_to_hwc(sample[dataset_keys.COLOR_IMG]/255)) + axes[1, 0].imshow(sample[dataset_keys.LOW_RES_DEPTH_IMG][0], vmin=0, vmax=max_depth, cmap=plt.get_cmap('turbo'), alpha=0.5) + plt.show() + plt.waitforbuttonpress() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("data_path", type=str, help="path to the dataset") + parser.add_argument("split", choices=['train', 'val'], type=str, help="sample split (train/val)") + parser.add_argument("sample_id", type=str, help="the id of the sample") + parser.add_argument("--max_depth", type=float, default=5, help="clip the depth image to max depth [meters]") + + args = parser.parse_args() + + sample_vis(args.data_path, args.split, args.sample_id, args.max_depth) diff --git a/data/ARKitScenes/depth_upsampling/sampler.py b/data/ARKitScenes/depth_upsampling/sampler.py new file mode 100644 index 0000000..03c989a --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/sampler.py @@ -0,0 +1,36 @@ +import numpy as np +import torch.utils.data + + +class MultiEpochSampler(torch.utils.data.Sampler): + r"""Samples elements randomly over multiple epochs + + Arguments: + data_source (Dataset): dataset to sample from + num_iter (int) : Number of times to loop over the dataset + start_itr (int) : which iteration to begin from + """ + + def __init__(self, data_source, num_iter, start_itr=0, batch_size=128): + super().__init__(data_source) + self.data_source = data_source + self.dataset_size = len(self.data_source) + self.num_iter = num_iter + self.start_itr = start_itr + self.batch_size = batch_size + self.num_epochs = int(np.ceil((self.num_iter * self.batch_size) / float(self.dataset_size))) + + if not isinstance(self.dataset_size, int) or self.dataset_size <= 0: + raise ValueError("dataset size should be a positive integeral " + "value, but got dataset_size={}".format(self.dataset_size)) + + def __iter__(self): + n = self.dataset_size + # Determine number of epochs + num_epochs = int(np.ceil(((self.num_iter - self.start_itr) * self.batch_size) / float(n))) + out = np.concatenate([np.random.permutation(n) for epoch in range(self.num_epochs)])[-num_epochs * n: self.num_iter * self.batch_size] + out = out[(self.start_itr * self.batch_size % n):] + return iter(out) + + def __len__(self): + return (self.num_iter - self.start_itr) * self.batch_size diff --git a/data/ARKitScenes/depth_upsampling/train.py b/data/ARKitScenes/depth_upsampling/train.py new file mode 100755 index 0000000..7107949 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/train.py @@ -0,0 +1,162 @@ +import argparse +import os +import shlex +import subprocess +import time + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from torchvision.transforms import Compose + +import transfroms +from dataset import ARKitScenesDataset +from logs.eval import eval_log +from logs.train import train_log +from losses import get_loss +from models import get_network +from sampler import MultiEpochSampler +from data_utils import batch_to_cuda + +TENSORBOARD_DIR = 'tensorboard' + + +def main(args): + batch_size = args.batch_size + num_iter = args.num_iter + upsample_factor = args.upsample_factor + start_itr = 0 + + patch_size = 256 if args.upsample_factor == 2 else 512 + + print('loading train dataset') + transform = Compose([transfroms.RandomCrop(height=patch_size, width=patch_size, upsample_factor=upsample_factor), + transfroms.RandomFilpLR(), + transfroms.ValidDepthMask(gt_low_limit=0.01), + transfroms.AsContiguousArray()]) + train_dataset = ARKitScenesDataset(root=args.data_path, split='train', + upsample_factor=upsample_factor, transform=transform) + sampler = MultiEpochSampler(train_dataset, num_iter, start_itr, batch_size) + train_dataloader = DataLoader(train_dataset, + batch_size, + sampler=sampler, + num_workers=8 * int(torch.cuda.is_available()), + pin_memory=torch.cuda.is_available(), + drop_last=True) + + print('loading validation dataset') + transform = Compose([transfroms.ModCrop(modulo=32), + transfroms.ValidDepthMask(gt_low_limit=0.01)]) + val_dataset = ARKitScenesDataset(root=args.data_path, split='val', + upsample_factor=upsample_factor, transform=transform) + val_dataloader = DataLoader(val_dataset, + batch_size=1, + num_workers=8 * int(torch.cuda.is_available()), + pin_memory=torch.cuda.is_available()) + + print('building the network') + model = get_network(args.network, upsample_factor) + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + if torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + if torch.cuda.is_available(): + model.cuda() + model.train() + cudnn.benchmark = True + + # init logs + if args.tbp is not None: + print('starting tensorboard') + tensorboard_path = os.path.join(args.log_dir, TENSORBOARD_DIR) + command = f'tensorboard --logdir {tensorboard_path} --port {args.tbp}' + tensorboard_process = subprocess.Popen(shlex.split(command), env=os.environ.copy()) + train_tensorboard_writer = SummaryWriter(os.path.join(tensorboard_path, 'train'), flush_secs=30) + val_tensorboard_writer = SummaryWriter(os.path.join(tensorboard_path, 'val'), flush_secs=30) + else: + print('no tensorboard') + tensorboard_process = None + train_tensorboard_writer = None + val_tensorboard_writer = None + + loss_fn = get_loss(args.network) + + start_time = time.time() + step = 1 + duration = 0 + current_lr = -1 + print("start training") + for input_batch in train_dataloader: + before_op_time = time.time() + input_batch = batch_to_cuda(input_batch) + + optimizer.zero_grad() + output_batch = model(input_batch) + loss = loss_fn(output_batch, input_batch) + + if np.isnan(loss.cpu().item()): + exit('NaN in loss occurred. Aborting training.') + + loss.backward() + optimizer.step() + + duration += time.time() - before_op_time + + train_log(step=step, loss=loss, input_batch=input_batch, output_batch=output_batch, + tensorboard_writer=train_tensorboard_writer, current_lr=current_lr) + if step % args.eval_freq == 0: + eval_log(step, model, val_dataloader, val_tensorboard_writer) + + if step and step % args.log_freq == 0: + examples_per_sec = args.batch_size / duration * args.log_freq + time_sofar = (time.time() - start_time) / 3600 + training_time_left = (num_iter / step - 1.0) * time_sofar + print_string = 'examples/s: {:4.2f} | time elapsed: {:.2f}h | time left: {:.2f}h' + print(print_string.format(examples_per_sec, time_sofar, training_time_left)) + duration = 0 + + if step % args.save_freq == 0: + checkpoint = {'step': step, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict()} + save_file = os.path.join(args.log_dir, 'checkpoint_step-{}'.format(step)) + torch.save(checkpoint, save_file) + + step += 1 + + print('finished training') + if tensorboard_process is not None: + tensorboard_process.terminate() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Depth upsamling training', fromfile_prefix_chars='@') + + # data + parser.add_argument('--data_path', type=str, help='The path to the dataset', default='~/ARKitScenes') + + # Network + parser.add_argument('--network', type=str, help='network model class', required=True) + + # Losses + parser.add_argument('--loss', type=str, help='loss for training', action='append') + + # Log and save + parser.add_argument('--log_dir', type=str, help='directory to save checkpoints and summaries', default='log') + parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default=None) + parser.add_argument('--log_freq', type=int, help='log frequency in steps', default=1000) + parser.add_argument('--eval_freq', type=int, help='run evaluation frequency in steps', default=10000) + parser.add_argument('--save_freq', type=int, help='Checkpoint saving frequency in steps', default=20000) + parser.add_argument('--tbp', type=int, help='tensorboard port', default=None) + + # Training + parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=5e-5) + parser.add_argument('--batch_size', type=int, help='batch size', default=16) + parser.add_argument('--num_iter', type=int, help='number of iteration to train', default=200000) + parser.add_argument('--upsample_factor', type=int, help='upsample scale from low to high resolution', + choices=[2, 4, 8]) + + args = parser.parse_args() + main(args) diff --git a/data/ARKitScenes/depth_upsampling/train.sh b/data/ARKitScenes/depth_upsampling/train.sh new file mode 100755 index 0000000..73406f8 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/train.sh @@ -0,0 +1,7 @@ +python train.py --log_dir log/MSPF_x2 --network MSPF --upsample_factor 2 +python train.py --log_dir log/MSPF_x4 --network MSPF --upsample_factor 4 +python train.py --log_dir log/MSPF_x8 --network MSPF --upsample_factor 8 + +python train.py --log_dir log/MSG_x2 --network MSG --upsample_factor 2 +python train.py --log_dir log/MSG_x4 --network MSG --upsample_factor 4 +python train.py --log_dir log/MSG_x8 --network MSG --upsample_factor 8 diff --git a/data/ARKitScenes/depth_upsampling/transfroms/__init__.py b/data/ARKitScenes/depth_upsampling/transfroms/__init__.py new file mode 100644 index 0000000..cfaa552 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/__init__.py @@ -0,0 +1,5 @@ +from .ascontiguousarray import AsContiguousArray +from .mod_crop import ModCrop +from .random_crop import RandomCrop +from .random_fliplr import RandomFilpLR +from .valid_depth_mask import ValidDepthMask diff --git a/data/ARKitScenes/depth_upsampling/transfroms/ascontiguousarray.py b/data/ARKitScenes/depth_upsampling/transfroms/ascontiguousarray.py new file mode 100644 index 0000000..29efff5 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/ascontiguousarray.py @@ -0,0 +1,12 @@ +import numpy as np + + +class AsContiguousArray: + def __init__(self): + super().__init__() + + def __call__(self, sample): + for key in sample: + if isinstance(sample[key], np.ndarray): + sample[key] = np.ascontiguousarray(sample[key]) + return sample diff --git a/data/ARKitScenes/depth_upsampling/transfroms/dilate_valid_mask.py b/data/ARKitScenes/depth_upsampling/transfroms/dilate_valid_mask.py new file mode 100644 index 0000000..340d276 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/dilate_valid_mask.py @@ -0,0 +1,15 @@ +import torch + +import dataset_keys + + +class DilateValidMask: + def __init__(self, dilation_radius: int): + self.dilation_radius = dilation_radius + self.layer = torch.nn.MaxPool2d(dilation_radius * 2 + 1, stride=1, padding=dilation_radius) + + def __call__(self, batch): + if dataset_keys.VALID_MASK_IMG in batch: + batch[dataset_keys.VALID_MASK_IMG] = self.layer((~batch[dataset_keys.VALID_MASK_IMG]).float()) == 0 + return batch + diff --git a/data/ARKitScenes/depth_upsampling/transfroms/mod_crop.py b/data/ARKitScenes/depth_upsampling/transfroms/mod_crop.py new file mode 100644 index 0000000..087bbcd --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/mod_crop.py @@ -0,0 +1,22 @@ +import dataset_keys + + +class ModCrop: + def __init__(self, modulo: int): + super().__init__() + self.modulo = modulo + + def __call__(self, sample): + img = sample[dataset_keys.COLOR_IMG] + depth = sample[dataset_keys.HIGH_RES_DEPTH_IMG] + tmpsz = depth.shape + sz = [tmpsz[1], tmpsz[2]] + sz[0] -= sz[0] % self.modulo + sz[1] -= sz[1] % self.modulo + + img = img[:, :sz[0], :sz[1]] + depth = depth[:, :sz[0], :sz[1]] + + sample[dataset_keys.COLOR_IMG] = img + sample[dataset_keys.HIGH_RES_DEPTH_IMG] = depth + return sample diff --git a/data/ARKitScenes/depth_upsampling/transfroms/random_crop.py b/data/ARKitScenes/depth_upsampling/transfroms/random_crop.py new file mode 100644 index 0000000..4025b28 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/random_crop.py @@ -0,0 +1,30 @@ +import random + +import dataset_keys + + +class RandomCrop: + def __init__(self, height: int, width: int, upsample_factor: int = None): + super().__init__() + self.height = height + self.width = width + self.upsample_factor = upsample_factor + + def __call__(self, sample): + low_res = sample[dataset_keys.LOW_RES_DEPTH_IMG].shape + low_res_patch_width = int(self.width / self.upsample_factor) + low_res_patch_height = int(self.height / self.upsample_factor) + x = random.randint(0, low_res[2] - low_res_patch_width) + y = random.randint(0, low_res[1] - low_res_patch_height) + + # crop low resolution depth image + img = sample[dataset_keys.LOW_RES_DEPTH_IMG] + img = img[:, y:y + low_res_patch_height, x:x + low_res_patch_width] + sample[dataset_keys.LOW_RES_DEPTH_IMG] = img + + # crop remaining + y *= self.upsample_factor + x *= self.upsample_factor + for key in [dataset_keys.COLOR_IMG, dataset_keys.HIGH_RES_DEPTH_IMG]: + sample[key] = sample[key][:, y:y + self.height, x:x + self.width] + return sample diff --git a/data/ARKitScenes/depth_upsampling/transfroms/random_fliplr.py b/data/ARKitScenes/depth_upsampling/transfroms/random_fliplr.py new file mode 100644 index 0000000..b3c2708 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/random_fliplr.py @@ -0,0 +1,17 @@ +import random + +import numpy as np + +import dataset_keys + + +class RandomFilpLR: + def __init__(self): + super().__init__() + + def __call__(self, sample): + if random.randint(0, 1): + asset_keys = [dataset_keys.COLOR_IMG, dataset_keys.HIGH_RES_DEPTH_IMG, dataset_keys.LOW_RES_DEPTH_IMG] + for key in asset_keys: + sample[key] = np.flip(sample[key], 2) + return sample diff --git a/data/ARKitScenes/depth_upsampling/transfroms/valid_depth_mask.py b/data/ARKitScenes/depth_upsampling/transfroms/valid_depth_mask.py new file mode 100644 index 0000000..50319b3 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/transfroms/valid_depth_mask.py @@ -0,0 +1,21 @@ +import dataset_keys + + +class ValidDepthMask: + def __init__(self, gt_low_limit: float = None, gt_high_limit: float = None): + assert gt_low_limit is None or gt_low_limit > 0, f'gt_low_limit must be greater than 0' + assert gt_high_limit is None or gt_high_limit > 0, f'gt_high_limit must be greater than 0' + self.gt_low_limit = gt_low_limit + self.gt_high_limit = gt_high_limit + + def __call__(self, sample): + if dataset_keys.VALID_MASK_IMG in sample: + valid_mask = sample[dataset_keys.VALID_MASK_IMG] + else: + valid_mask = sample[dataset_keys.HIGH_RES_DEPTH_IMG] != 0 + if self.gt_low_limit is not None: + valid_mask = valid_mask & (sample[dataset_keys.HIGH_RES_DEPTH_IMG] > self.gt_low_limit) + if self.gt_high_limit is not None: + valid_mask = valid_mask & (sample[dataset_keys.HIGH_RES_DEPTH_IMG] < self.gt_high_limit) + sample[dataset_keys.VALID_MASK_IMG] = valid_mask + return sample diff --git a/data/ARKitScenes/depth_upsampling/upsampling_train_val_splits.csv b/data/ARKitScenes/depth_upsampling/upsampling_train_val_splits.csv new file mode 100644 index 0000000..1c30f77 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/upsampling_train_val_splits.csv @@ -0,0 +1,2258 @@ +video_id,visit_id,fold +41048190,381531,Training +41048223,381654,Training +41048225,381654,Training +41048229,381654,Training +41048247,381652,Training +41048249,381652,Training +41048251,381652,Training +41048262,381650,Training +41048264,381650,Training +41048265,381650,Training +42444474,421069,Training +42444477,421069,Training +42444490,421069,Training +42444499,421065,Training +42444501,421065,Training +42444503,421065,Training +42444511,421063,Training +42444512,421063,Training +42444513,421063,Training +42444514,421061,Training +42444515,421061,Training +42444517,421061,Training +42444574,421062,Training +42444579,421062,Training +42444588,421062,Training +42444692,421016,Training +42444695,421016,Training +42444696,421016,Training +42444703,421013,Training +42444706,421013,Training +42444708,421013,Training +42444709,421010,Training +42444711,421010,Training +42444712,421010,Training +42444716,421005,Training +42444719,421005,Training +42444721,421005,Training +42444733,421267,Training +42444735,421267,Training +42444738,421267,Training +42444748,421259,Training +42444750,421259,Training +42444751,421259,Training +42444754,421254,Training +42444755,421254,Training +42444758,421254,Training +42444762,421252,Training +42444767,421252,Training +42444768,421252,Training +42444787,421015,Training +42444789,421015,Training +42444791,421012,Training +42444793,421012,Training +42444794,421012,Training +42444821,421255,Training +42444822,421255,Training +42444826,421255,Training +42444858,421009,Training +42444859,421009,Training +42444860,421009,Training +42444866,421006,Training +42444869,421006,Training +42444872,421006,Training +42444873,421002,Training +42444875,421002,Training +42444876,421002,Training +42444883,421264,Training +42444885,421264,Training +42444887,421264,Training +42444891,421260,Training +42444892,421260,Training +42444896,421260,Training +42444904,421256,Training +42444907,421256,Training +42444913,421397,Training +42444916,421397,Training +42444917,421397,Training +42444923,421393,Training +42444924,421393,Training +42444928,421392,Training +42444932,421392,Training +42444933,421392,Training +42445044,421659,Training +42445047,421659,Training +42445057,421652,Training +42445060,421652,Training +42445063,421652,Training +42445078,421647,Training +42445079,421647,Training +42445081,421647,Training +42445100,422200,Training +42445103,422200,Training +42445132,420683,Training +42445135,420683,Training +42445137,420683,Training +42445168,420693,Training +42445169,420693,Training +42445173,420693,Training +42445198,420673,Training +42445205,420673,Training +42445211,420673,Training +42445428,421391,Training +42445444,421386,Training +42445445,421386,Training +42445451,421386,Training +42445476,421379,Training +42445478,421379,Training +42445479,421379,Training +42445494,421853,Training +42445498,421853,Training +42445502,421853,Training +42445584,421644,Training +42445587,421628,Training +42445592,421628,Training +42445597,421602,Training +42445599,421602,Training +42445611,422214,Training +42445612,421667,Training +42445615,422214,Training +42445619,421667,Training +42445633,421657,Training +42445639,421657,Training +42445642,421657,Training +42445670,422163,Training +42445676,422163,Training +42445680,422155,Training +42445684,422155,Training +42445689,422155,Training +42445691,421655,Training +42445692,422148,Training +42445695,421655,Training +42445697,422148,Training +42445698,421655,Training +42445707,422134,Training +42445716,422134,Training +42445718,422217,Training +42445720,422217,Training +42445721,422134,Training +42445723,422217,Training +42445728,421616,Training +42445729,421616,Training +42445736,421616,Training +42445745,421593,Training +42445758,421593,Training +42445766,421593,Training +42445769,421658,Training +42445770,421593,Training +42445771,421658,Training +42445775,422203,Training +42445781,422203,Training +42445782,421658,Training +42445783,422203,Training +42445784,421654,Training +42445785,422203,Training +42445788,421654,Training +42445790,422195,Training +42445794,421654,Training +42445796,422195,Training +42445799,422195,Training +42445802,422182,Training +42445804,422182,Training +42445806,422182,Training +42445834,422149,Training +42445862,422023,Training +42445864,422023,Training +42445865,422023,Training +42445869,422017,Training +42445872,422017,Training +42445873,422017,Training +42445877,422013,Training +42445881,422013,Training +42445882,422013,Training +42445884,422009,Training +42445888,422009,Training +42445889,422009,Training +42445891,422009,Training +42445894,422589,Training +42445902,422149,Training +42445903,422149,Training +42445913,422589,Training +42445916,422569,Training +42445922,422551,Training +42445924,422551,Training +42445927,422551,Training +42445931,422539,Training +42445938,422539,Training +42445970,422011,Training +42445987,422011,Training +42445999,422010,Training +42446008,422010,Training +42446017,422007,Training +42446031,422006,Training +42446036,422006,Training +42446039,422006,Training +42446048,422535,Training +42446050,422535,Training +42446057,422535,Training +42446061,422521,Training +42446068,422521,Training +42446080,421948,Training +42446159,422546,Training +42446164,422546,Training +42446445,422543,Training +42446450,422543,Training +42446467,422523,Training +42446468,422523,Training +42446478,422523,Training +42446492,422518,Training +42446493,422518,Training +42446495,422516,Training +42446497,422516,Training +42446558,422399,Training +42446561,422399,Training +42446574,422399,Training +42446576,422356,Training +42446579,422356,Training +42446605,422323,Training +42446607,422384,Training +42446608,422323,Training +42447199,422384,Training +42447202,423070,Training +42447203,422384,Training +42447205,423070,Training +42447210,423070,Training +42447214,422380,Training +42447221,423511,Training +42447226,423511,Training +42447230,422380,Training +42447233,422380,Training +42447275,422383,Training +42447287,422383,Training +42447294,422383,Training +42447307,422378,Training +42447308,422378,Training +42447310,422378,Training +42447320,422377,Training +42447329,422377,Training +42447336,422377,Training +42897405,423474,Training +42897409,423474,Training +42897410,423474,Training +42897418,423461,Training +42897419,423461,Training +42897421,423461,Training +42897422,423452,Training +42897426,423452,Training +42897434,423452,Training +42897436,423438,Training +42897439,423438,Training +42897442,423438,Training +42897452,422862,Training +42897455,422862,Training +42897478,423442,Training +42897479,423442,Training +42897480,422855,Training +42897482,423442,Training +42897490,422855,Training +42897509,422847,Training +42897512,422847,Training +42897523,422847,Training +42897547,422842,Training +42897560,422842,Training +42897598,422379,Training +42897600,422379,Training +42897605,422376,Training +42897607,422376,Training +42897612,422376,Training +42897631,422354,Training +42897633,422354,Training +42897634,422354,Training +42897655,422849,Training +42897675,422838,Training +42897677,422838,Training +42897681,422838,Training +42897709,423337,Training +42897712,423337,Training +42897713,423337,Training +42897720,423325,Training +42897722,423325,Training +42897732,423312,Training +42897735,423312,Training +42897736,423312,Training +42897743,423306,Training +42897744,423306,Training +42897745,423306,Training +42897755,423747,Training +42897756,423747,Training +42897771,423335,Training +42897776,423307,Training +42897777,423307,Training +42897780,423307,Training +42897783,423306,Training +42897784,423306,Training +42897785,423306,Training +42897815,423738,Training +42897818,423738,Training +42897848,423320,Training +42897851,423320,Training +42897857,423320,Training +42897863,423315,Training +42897868,423315,Training +42897877,423310,Training +42897892,423296,Training +42897898,423296,Training +42897924,423801,Training +42897925,423801,Training +42897928,423801,Training +42897930,423792,Training +42897931,423792,Training +42897934,423792,Training +42897939,423782,Training +42897945,423782,Training +42897948,423782,Training +42897955,423324,Training +42897960,423324,Training +42897967,423324,Training +42898006,423311,Training +42898007,423311,Training +42898052,423791,Training +42898057,423777,Training +42898059,423777,Training +42898061,423777,Training +42898065,423770,Training +42898067,423770,Training +42898068,423770,Training +42898070,423614,Training +42898071,423614,Training +42898075,423614,Training +42898083,423611,Training +42898087,423611,Training +42898089,423611,Training +42898094,423605,Training +42898097,423605,Training +42898098,423605,Training +42898100,423613,Training +42898109,423613,Training +42898112,423613,Training +42898123,423980,Training +42898125,423980,Training +42898132,423978,Training +42898135,423978,Training +42898141,423978,Training +42898156,426265,Training +42898160,426265,Training +42898162,426265,Training +42898163,426259,Training +42898169,426259,Training +42898182,426259,Training +42898189,426247,Training +42898191,426247,Training +42898195,426247,Training +42898221,423989,Training +42898230,423989,Training +42898234,423989,Training +42898236,423974,Training +42898247,423974,Training +42898248,423974,Training +42898332,423966,Training +42898334,423966,Training +42898337,423966,Training +42898340,423957,Training +42898342,423957,Training +42898343,423957,Training +42898345,423953,Training +42898348,423953,Training +42898391,423964,Training +42898392,423964,Training +42898393,423964,Training +42898405,423956,Training +42898407,423956,Training +42898408,423956,Training +42898447,426262,Training +42898448,426262,Training +42898449,426262,Training +42898454,426166,Training +42898458,426166,Training +42898461,426166,Training +42898470,426153,Training +42898477,426153,Training +42898500,434689,Training +42898501,434689,Training +42898502,434689,Training +42898510,434687,Training +42898511,434687,Training +42898526,426168,Training +42898551,426156,Training +42898555,426156,Training +42898558,426156,Training +42898560,434700,Training +42898571,434700,Training +42898577,434700,Training +42898586,434695,Training +42898587,434695,Training +42898596,434695,Training +42898738,426154,Training +42898745,426154,Training +42898750,426150,Training +42898751,426150,Training +42898752,426150,Training +42898754,426146,Training +42898756,426146,Training +42898760,426146,Training +42898762,426143,Training +42898765,426143,Training +42898768,426143,Training +42898802,434688,Training +42898803,434688,Training +42898808,434688,Training +42899109,434969,Training +42899111,434969,Training +42899112,434969,Training +42899118,434964,Training +42899119,434964,Training +42899120,434964,Training +42899125,434959,Training +42899126,434959,Training +42899128,434959,Training +42899132,434955,Training +42899137,434955,Training +42899139,434955,Training +42899140,434952,Training +42899141,434952,Training +42899144,434952,Training +42899148,434909,Training +42899153,434909,Training +42899154,434909,Training +42899163,434897,Training +42899165,434897,Training +42899175,434892,Training +42899177,434892,Training +42899178,434892,Training +42899184,434888,Training +42899185,434888,Training +42899187,434888,Training +42899198,435330,Training +42899202,435330,Training +42899209,435330,Training +42899210,435327,Training +42899211,435325,Training +42899214,435325,Training +42899215,435325,Training +42899216,435324,Training +42899220,435324,Training +42899221,435324,Training +42899259,434903,Training +42899287,435327,Training +42899292,435327,Training +42899295,435327,Training +42899432,434872,Training +42899433,434872,Training +42899435,434872,Training +42899445,435329,Training +42899448,435329,Training +42899452,435328,Training +42899453,435328,Training +42899455,435328,Training +42899489,435359,Training +42899491,435359,Training +42899513,435359,Training +42899526,435358,Training +42899529,435358,Training +42899531,435358,Training +42899533,435358,Training +42899540,435353,Training +42899542,435353,Training +42899543,435353,Training +42899550,435345,Training +42899559,435345,Training +42899562,435345,Training +42899598,435715,Training +42899600,435715,Training +42899603,435715,Training +42899624,435357,Training +42899630,435357,Training +42899632,435357,Training +42899650,435729,Training +42899653,435729,Training +42899654,435729,Training +42899657,435723,Training +42899664,435723,Training +42899666,435718,Training +42899668,435718,Training +42899670,435718,Training +42899774,435363,Training +42899775,435363,Training +42899778,435363,Training +42899780,435363,Training +42899789,435342,Training +42899795,435342,Training +42899797,435342,Training +42899817,435730,Training +42899819,435730,Training +42899821,435730,Training +42899822,435724,Training +42899828,435724,Training +42899829,435724,Training +42899968,435664,Training +42899970,435664,Training +42899975,435664,Training +42899977,435661,Training +42923203,435661,Training +42923204,435661,Training +42923207,435659,Training +42923208,435659,Training +42923211,435659,Training +42923225,435654,Training +42923226,435654,Training +43649384,437126,Training +43649391,437126,Training +43649393,437126,Training +43649397,436998,Training +43649399,436998,Training +43649403,436998,Training +43649408,436997,Training +43649409,436997,Training +43649417,436997,Training +43649421,436996,Training +43649425,436996,Training +43649426,436996,Training +43649440,437135,Training +43649451,437135,Training +43649452,437135,Training +43649459,437134,Training +43649463,437134,Training +43649464,437134,Training +43649480,437125,Training +43649484,437125,Training +43649586,437169,Training +43649589,437169,Training +43649590,437169,Training +43649597,437168,Training +43649598,437168,Training +43649603,437168,Training +43649605,437167,Training +43649609,437167,Training +43649612,437167,Training +43649613,437165,Training +43649614,437165,Training +43649615,437165,Training +43649633,437154,Training +43649634,437154,Training +43649639,437154,Training +43649643,437253,Training +43649647,437253,Training +43649648,437253,Training +43649654,437252,Training +43649660,437252,Training +43649662,437252,Training +43649680,437159,Training +43649681,437159,Training +43649685,437159,Training +43649686,437157,Training +43649688,437157,Training +43649692,437157,Training +43649707,437221,Training +43649714,437221,Training +43649720,437221,Training +43649722,437219,Training +43649727,437219,Training +43649729,437219,Training +43649732,437218,Training +43649742,437218,Training +43649743,437218,Training +43649754,437298,Training +43649755,437298,Training +43649756,437298,Training +43649762,437294,Training +43649763,437294,Training +43649767,437294,Training +43649772,437220,Training +43649774,437220,Training +43649775,437220,Training +43649787,437210,Training +43649794,437210,Training +43649798,437210,Training +43649802,437299,Training +43649804,437299,Training +43828138,437452,Training +43828142,437452,Training +43828144,437452,Training +43828149,437450,Training +43828154,437450,Training +43828162,437445,Training +43828230,437448,Training +43828231,437448,Training +43828233,437448,Training +43828245,437429,Training +43828250,437429,Training +43828270,437546,Training +43828282,437546,Training +43828307,437430,Training +43828323,437424,Training +43828334,437544,Training +43828339,437544,Training +43828340,437544,Training +43828345,437543,Training +43828350,437543,Training +43828351,437543,Training +43828361,437539,Training +43828362,437539,Training +43828366,437534,Training +43828368,437534,Training +43828369,437534,Training +43828380,437533,Training +43828381,437533,Training +43828382,437533,Training +43828390,437709,Training +43828391,437709,Training +43828393,437709,Training +43828402,437530,Training +43828404,437530,Training +43828421,437528,Training +43828423,437528,Training +43828435,437524,Training +43828437,437524,Training +43828441,437521,Training +43828442,437521,Training +43828449,437521,Training +43828455,437520,Training +43828456,437520,Training +43828457,437520,Training +43828460,437711,Training +43828463,437711,Training +43828464,437711,Training +43828476,437692,Training +43828478,437692,Training +43895956,437986,Training +43895960,437986,Training +43895969,437986,Training +43896024,437985,Training +43896035,437985,Training +43896055,438029,Training +43896063,438029,Training +43896089,438122,Training +43896169,438017,Training +43896170,438017,Training +43896177,438118,Training +43896183,438118,Training +43896186,438116,Training +43896187,438116,Training +43896188,438116,Training +43896199,438115,Training +43896223,438109,Training +43896224,438109,Training +43896229,438109,Training +43896244,438102,Training +43896247,438102,Training +43896249,438102,Training +43896251,438101,Training +43896256,438101,Training +44358167,438779,Training +44358170,438775,Training +44358173,438775,Training +44358176,438775,Training +44358220,438796,Training +44358222,438796,Training +44358225,438796,Training +44358227,438796,Training +44358230,438795,Training +44358234,438795,Training +44358235,438795,Training +44358238,438794,Training +44358241,438794,Training +44358243,438794,Training +44358250,438804,Training +44358252,438804,Training +44358253,438804,Training +44358256,438802,Training +44358257,438802,Training +44358258,438802,Training +44358259,438802,Training +44358263,438801,Training +44358268,438801,Training +44358270,438801,Training +44358274,438793,Training +44358278,438793,Training +44358281,438793,Training +44358289,438966,Training +44358291,438966,Training +44358295,438966,Training +44358324,442392,Training +44358327,442392,Training +44358328,442392,Training +44358340,438968,Training +44358341,438968,Training +44358345,438968,Training +44358360,438967,Training +44358361,438967,Training +44358378,442394,Training +44358380,442394,Training +44358386,442394,Training +44358406,455340,Training +44358407,455340,Training +44358410,455340,Training +44358418,455339,Training +44358419,455339,Training +44358422,455339,Training +44358423,455337,Training +44358426,455337,Training +44358428,455337,Training +44358471,455342,Training +44358472,455342,Training +44358617,464236,Training +44358620,464236,Training +44358634,464761,Training +44358645,464761,Training +44358647,464761,Training +44358660,464233,Training +44358673,464233,Training +44358691,464230,Training +44358693,464230,Training +44358698,464230,Training +44358700,464762,Training +44358707,464762,Training +44358710,464762,Training +44358728,464758,Training +44358729,464758,Training +44358731,464758,Training +44796189,464755,Training +44796190,464755,Training +44796191,464755,Training +44796252,464746,Training +44796254,464746,Training +44796257,464746,Training +44796271,464754,Training +44796279,464754,Training +44796280,464754,Training +44796282,464754,Training +44796304,464803,Training +44796310,464803,Training +44796311,464803,Training +44796324,464789,Training +44796326,464789,Training +44796332,464785,Training +44796340,464785,Training +44796343,464785,Training +44796358,464795,Training +44796366,464795,Training +44796367,464795,Training +44796374,464792,Training +44796379,464792,Training +44796387,464792,Training +44796438,464981,Training +44796443,464981,Training +44796446,464981,Training +44796450,464977,Training +44796451,464977,Training +44796455,464977,Training +44796473,466111,Training +44796475,466111,Training +44796477,466111,Training +44796481,464982,Training +44796483,464982,Training +44796484,464982,Training +44796485,464980,Training +44796487,464980,Training +44796488,464980,Training +44796491,464979,Training +44796492,464979,Training +44796498,464979,Training +44796505,464975,Training +44796507,464975,Training +44796517,466112,Training +44796520,466112,Training +44796521,466112,Training +44796562,466092,Training +44796568,466092,Training +44796575,466162,Training +44796576,466162,Training +44796579,466162,Training +44796584,466160,Training +44796585,466160,Training +44796596,466159,Training +44796598,466159,Training +44796709,466242,Training +44796722,466242,Training +44796744,466238,Training +44796752,466234,Training +45260889,466433,Training +45260890,466433,Training +45260891,466433,Training +45260892,466433,Training +45260895,466433,Training +45260946,466176,Training +45260947,466176,Training +45260951,466437,Training +45260952,466437,Training +45260957,466437,Training +45261080,466656,Training +45261087,466656,Training +45261089,466656,Training +45261096,466652,Training +45261099,466652,Training +45261100,466652,Training +45261108,466652,Training +45261153,466644,Training +45261158,466644,Training +45261164,466644,Training +45261167,466637,Training +45261169,466637,Training +45261172,466637,Training +45261201,466741,Training +45261203,466741,Training +45261206,466741,Training +45261219,466717,Training +45261228,466717,Training +45261230,466717,Training +45261241,466964,Training +45261246,466964,Training +45261248,466964,Training +45261251,466750,Training +45261252,466750,Training +45261257,466750,Training +45261263,466739,Training +45261265,466739,Training +45261269,466739,Training +45261275,466731,Training +45261277,466731,Training +45261281,466731,Training +45261288,466967,Training +45261289,466967,Training +45261290,466967,Training +45261292,466965,Training +45261294,466965,Training +45261297,466965,Training +45261313,466857,Training +45261314,466857,Training +45261317,466828,Training +45261323,466828,Training +45261326,466828,Training +45261329,467253,Training +45261331,467253,Training +45261332,467253,Training +45261339,466896,Training +45261343,466896,Training +45261344,466896,Training +45261355,466879,Training +45261357,466879,Training +45261378,466869,Training +45261381,466869,Training +45261382,466869,Training +45261384,466857,Training +45261390,466857,Training +45261392,466857,Training +45261395,466832,Training +45261397,466832,Training +45261399,466832,Training +45261406,467252,Training +45261407,467252,Training +45261413,467252,Training +45261452,467285,Training +45261455,467285,Training +45261457,467285,Training +45261485,467105,Training +45261486,467105,Training +45261490,467105,Training +45261495,467102,Training +45261499,467102,Training +45261502,467102,Training +45261507,467287,Training +45261509,467287,Training +45261515,467287,Training +45261522,467283,Training +45261524,467283,Training +45261527,467283,Training +45261532,467334,Training +45261533,467334,Training +45261541,467334,Training +45261544,467326,Training +45261547,467326,Training +45261548,467326,Training +45261556,467324,Training +45261562,467324,Training +45261565,467312,Training +45261569,467312,Training +45261571,467312,Training +45261598,467330,Training +45261600,467330,Training +45261601,467330,Training +45261641,468402,Training +45261645,468402,Training +45261651,468393,Training +45261653,468393,Training +45261658,468393,Training +45261662,468479,Training +45261665,468479,Training +45261666,468479,Training +45261672,468477,Training +45261675,468477,Training +45261677,468477,Training +45261681,468476,Training +45261682,468476,Training +45261686,468476,Training +45261688,468410,Training +45261689,468410,Training +45261695,468410,Training +45261709,468397,Training +45261710,468397,Training +45261713,468397,Training +45261722,468380,Training +45261725,468380,Training +45261728,468380,Training +45662791,468509,Training +45662796,468509,Training +45662800,468509,Training +45662805,468498,Training +45662806,468498,Training +45662809,468498,Training +45662812,468485,Training +45662813,468485,Training +45662816,468485,Training +45662826,468605,Training +45662833,468602,Training +45662835,468602,Training +45662839,468602,Training +45662870,468496,Training +45662874,468496,Training +45662875,468496,Training +45662882,468603,Training +45662884,468603,Training +45662886,468603,Training +45662890,468659,Training +45662891,468659,Training +45662897,468659,Training +45662905,468654,Training +45662908,468654,Training +45662912,468654,Training +45662950,468650,Training +45662951,468650,Training +45662956,468650,Training +45662959,468649,Training +45662960,468649,Training +45662962,468649,Training +45662995,468779,Training +45662996,468779,Training +45662998,468779,Training +45663007,468771,Training +45663009,468771,Training +45663044,468781,Training +45663051,468781,Training +45663053,468775,Training +45663064,468775,Training +45663065,468775,Training +45663223,469235,Training +45663226,469235,Training +45663227,469235,Training +45663248,469260,Training +45663250,469260,Training +45663255,469260,Training +45663257,469238,Training +45663258,469238,Training +45663262,469238,Training +45663312,469419,Training +45663313,469419,Training +45663317,469419,Training +45663322,469403,Training +45663330,469403,Training +45663332,469403,Training +45663340,469552,Training +45663342,469552,Training +45663343,469552,Training +45663347,469449,Training +45663349,469449,Training +45663363,469449,Training +45663368,469434,Training +45663376,469434,Training +45663377,469434,Training +45663397,469560,Training +45663398,469560,Training +45663401,469560,Training +45663404,469556,Training +45663409,469556,Training +45663415,469556,Training +47115118,469635,Training +47115126,469635,Training +47115127,469635,Training +47115129,469619,Training +47115131,469619,Training +47115132,469619,Training +47115135,469617,Training +47115136,469617,Training +47115137,469617,Training +47115143,469607,Training +47115148,469607,Training +47115150,469607,Training +47115154,469603,Training +47115156,469603,Training +47115158,469603,Training +47115161,469601,Training +47115165,469601,Training +47115166,469601,Training +47115180,469627,Training +47115194,469600,Training +47115195,469600,Training +47115197,469600,Training +47115198,469822,Training +47115199,469822,Training +47115201,469822,Training +47115205,469807,Training +47115216,469807,Training +47115217,469807,Training +47115222,469800,Training +47115227,469800,Training +47115235,469800,Training +47115238,470052,Training +47115240,470052,Training +47115244,470052,Training +47115255,470039,Training +47115256,470039,Training +47115258,470039,Training +47115276,470053,Training +47115287,470053,Training +47115299,470046,Training +47115304,470046,Training +47115321,470334,Training +47115322,470334,Training +47115334,470313,Training +47115336,470313,Training +47115338,470313,Training +47115361,470325,Training +47115362,470325,Training +47115363,470325,Training +47115379,470308,Training +47115384,470308,Training +47115412,470445,Training +47115415,470445,Training +47115422,470439,Training +47115427,470439,Training +47115442,470515,Training +47115445,470515,Training +47115448,470515,Training +47115501,470523,Training +47115504,470523,Training +47115508,470523,Training +47115573,470944,Training +47115574,470944,Training +47115575,470944,Training +47115619,470804,Training +47115621,470804,Training +47115623,470804,Training +47115624,470945,Training +47115628,470945,Training +47115634,470945,Training +47115648,470942,Training +47115649,470942,Training +47115650,470942,Training +47204271,470998,Training +47204274,470998,Training +47204275,470998,Training +47204289,470975,Training +47204291,470975,Training +47204296,470975,Training +47204303,471103,Training +47204304,471103,Training +47204305,471103,Training +47204316,471098,Training +47204319,471098,Training +47204325,471014,Training +47204329,471014,Training +47204330,471014,Training +47204335,471004,Training +47204341,471004,Training +47204342,471004,Training +47204350,470995,Training +47204362,470995,Training +47204363,470995,Training +47204387,471172,Training +47204390,471172,Training +47204392,471172,Training +47204398,471160,Training +47204449,471164,Training +47204451,471164,Training +47204453,471164,Training +47204473,471146,Training +47204482,471146,Training +47204493,471304,Training +47204494,471304,Training +47204499,471304,Training +47204505,471663,Training +47204506,471663,Training +47204517,471663,Training +47204519,471655,Training +47204522,471655,Training +47204524,471655,Training +47204529,471650,Training +47204533,471650,Training +47204535,471650,Training +47204585,471656,Training +47204589,471656,Training +47204591,471656,Training +47204724,472331,Training +47204726,472331,Training +47204730,472331,Training +47204734,472312,Training +47204736,472312,Training +47204739,472312,Training +47204741,472484,Training +47204744,472484,Training +47204747,472484,Training +47204757,472475,Training +47204761,472475,Training +47204763,472475,Training +47204780,472473,Training +47204783,472473,Training +47204786,472473,Training +47204791,472335,Training +47204795,472335,Training +47204800,472335,Training +47204813,472306,Training +47204816,472306,Training +47204818,472306,Training +47204820,472477,Training +47204824,472477,Training +47204828,472477,Training +47204832,472603,Training +47204836,472603,Training +47204840,472582,Training +47204842,472582,Training +47204843,472582,Training +47204850,472575,Training +47204852,472575,Training +47204856,472575,Training +47204857,472575,Training +47204858,472575,Training +47204861,478020,Training +47204863,478020,Training +47204864,478020,Training +47204874,478016,Training +47204876,478016,Training +47204877,478016,Training +47204887,472595,Training +47204888,472595,Training +47204892,472595,Training +47204919,478018,Training +47204920,478018,Training +47204921,478018,Training +47330997,470341,Training +47331002,470341,Training +47331003,470341,Training +47331008,470326,Training +47331010,470326,Training +47331039,470318,Training +47331040,470318,Training +47331041,470318,Training +47331048,470309,Training +47331049,470309,Training +47331053,470309,Training +47331133,470348,Training +47331134,470348,Training +47331136,470348,Training +47331139,470329,Training +47331140,470329,Training +47331143,470329,Training +47331144,470316,Training +47331146,470316,Training +47331148,470316,Training +47331159,470310,Training +47331173,471419,Training +47331174,471419,Training +47331175,471419,Training +47331183,471414,Training +47331240,471374,Training +47331273,468298,Training +47331276,468298,Training +47331279,468282,Training +47331281,468282,Training +47331283,468282,Training +47331292,468275,Training +47331294,468275,Training +47331298,468275,Training +47331355,470546,Training +47331356,470546,Training +47331358,470546,Training +47331388,470499,Training +47331390,470499,Training +47331393,470499,Training +47331495,470493,Training +47331496,470493,Training +47331498,470493,Training +47331518,470551,Training +47331520,470551,Training +47331522,470551,Training +47331524,470544,Training +47331525,470544,Training +47331527,470544,Training +47331533,466908,Training +47331536,466908,Training +47331539,466908,Training +47331541,466895,Training +47331542,466895,Training +47331544,466895,Training +47331549,466884,Training +47331552,466884,Training +47331557,466884,Training +47331558,466876,Training +47331560,466876,Training +47331561,466876,Training +47331573,466850,Training +47331575,466850,Training +47331576,466850,Training +47331587,466841,Training +47331589,466841,Training +47331591,466841,Training +47331604,466835,Training +47331606,466835,Training +47331607,466835,Training +47331615,466916,Training +47331617,466916,Training +47331618,466916,Training +47331619,466906,Training +47331621,466906,Training +47331622,466906,Training +47331625,466899,Training +47331627,466899,Training +47331631,466899,Training +47331684,466911,Training +47331686,466911,Training +47331707,466880,Training +47331710,466880,Training +47331711,466880,Training +47331762,466834,Training +47331765,466834,Training +47331766,466834,Training +47331770,466831,Training +47331771,466831,Training +47331772,466831,Training +47331793,467254,Training +47331795,467254,Training +47331796,467254,Training +47331931,470891,Training +47331938,470868,Training +47331942,470868,Training +47331945,470863,Training +47331946,470863,Training +47331947,470863,Training +47332015,470876,Training +47332016,470876,Training +47332046,470830,Training +47332048,470830,Training +47332049,470830,Training +47332066,470977,Training +47332072,470977,Training +47332074,470977,Training +47332075,470974,Training +47332079,470974,Training +47332080,470974,Training +47332081,471106,Training +47332083,471106,Training +47332084,471106,Training +47332152,471179,Training +47332156,471179,Training +47332161,471176,Training +47332195,471148,Training +47332197,471148,Training +47332198,471148,Training +47332297,471685,Training +47332298,471685,Training +47332301,471685,Training +47332311,471674,Training +47332312,471674,Training +47332314,471674,Training +47332320,471648,Training +47332401,472042,Training +47332415,472030,Training +47332417,472030,Training +47332418,472030,Training +47332480,472334,Training +47332482,472334,Training +47332487,472326,Training +47332489,472326,Training +47332490,472326,Training +47332503,472308,Training +47332505,472308,Training +47332526,472349,Training +47332529,472349,Training +47332530,472349,Training +47332593,472626,Training +47332597,472621,Training +47332598,472621,Training +47332599,472621,Training +47332618,472591,Training +47332622,472591,Training +47332698,469269,Training +47332703,469269,Training +47332704,469269,Training +47332705,469267,Training +47332746,469278,Training +47332748,469278,Training +47332751,469278,Training +47332757,469268,Training +47332758,469268,Training +47332759,469268,Training +47332764,469264,Training +47332766,469264,Training +47332769,469264,Training +47332774,469256,Training +47332777,469256,Training +47332778,469256,Training +47332792,469226,Training +47332795,469226,Training +47332796,469226,Training +47332808,469220,Training +47332810,469220,Training +47332812,469220,Training +47332817,469218,Training +47332818,469218,Training +47332819,469218,Training +47332833,469271,Training +47332834,469271,Training +47332837,469271,Training +47332937,469217,Training +47332939,469217,Training +47332940,469217,Training +47332946,469216,Training +47332950,469216,Training +47332951,469216,Training +47332952,469214,Training +47332954,469214,Training +47332956,469214,Training +47332983,469464,Training +47332986,469459,Training +47332987,469459,Training +47332988,469459,Training +47333031,469467,Training +47333033,469467,Training +47333035,469467,Training +47333040,469460,Training +47333043,469460,Training +47333045,469460,Training +47333053,469455,Training +47333055,469455,Training +47333056,469455,Training +47333058,469446,Training +47333060,469446,Training +47333061,469446,Training +47333122,469424,Training +47333125,469424,Training +47333133,469461,Training +47333134,469461,Training +47333135,469461,Training +47333138,469456,Training +47333139,469456,Training +47333152,469456,Training +47333153,469452,Training +47333155,469452,Training +47333159,469452,Training +47333164,469447,Training +47333165,469447,Training +47333166,469447,Training +47333168,469443,Training +47333174,469443,Training +47333175,469443,Training +47333188,469413,Training +47333189,469413,Training +47333194,469413,Training +47333195,469413,Training +47333199,469413,Training +47333201,469413,Training +47333203,469413,Training +47333205,469411,Training +47333206,469411,Training +47333209,469411,Training +47333222,469406,Training +47333224,469406,Training +47333225,469406,Training +47333233,467181,Training +47333239,467181,Training +47333241,467181,Training +47333243,467160,Training +47333244,467160,Training +47333249,467160,Training +47333264,467179,Training +47333265,467179,Training +47333268,467179,Training +47333274,467170,Training +47333277,467170,Training +47333281,467152,Training +47333283,467152,Training +47333288,467152,Training +47333292,467139,Training +47333293,467139,Training +47333298,467139,Training +47333308,467115,Training +47333310,467115,Training +47333319,467115,Training +47333323,467109,Training +47333324,467109,Training +47333325,467109,Training +47333329,467106,Training +47333331,467106,Training +47333332,467106,Training +47333369,467166,Training +47333370,467166,Training +47333372,467166,Training +47333468,467138,Training +47333472,467135,Training +47333473,467135,Training +47333475,467135,Training +47333477,467128,Training +47333479,467128,Training +47333480,467128,Training +47333482,467125,Training +47333484,467125,Training +47333486,467125,Training +47333487,467125,Training +47333497,467119,Training +47333512,467175,Training +47333513,467175,Training +47333514,467175,Training +47333521,467159,Training +47333561,469650,Training +47333562,469650,Training +47333564,469650,Training +47333565,469644,Training +47333574,469624,Training +47333589,469612,Training +47333590,469612,Training +47333593,469612,Training +47333603,469659,Training +47333610,469649,Training +47333612,469649,Training +47333618,469649,Training +47333623,469642,Training +47333624,469642,Training +47333625,469642,Training +47333635,469654,Training +47333636,469654,Training +47333637,469654,Training +47333647,469647,Training +47333659,469616,Training +47333660,469616,Training +47333665,469614,Training +47333666,469614,Training +47333668,469614,Training +47333774,467370,Training +47333776,467370,Training +47333778,467370,Training +47333785,467360,Training +47333786,467360,Training +47333787,467360,Training +47333801,467291,Training +47333803,467291,Training +47333806,467291,Training +47333810,468077,Training +47333812,468077,Training +47333815,468077,Training +47333826,468074,Training +47333827,468074,Training +47333828,468074,Training +47333836,468071,Training +47333838,468071,Training +47333840,468071,Training +47333846,467367,Training +47333847,467367,Training +47333851,467367,Training +47333862,467357,Training +47333863,467357,Training +47333864,467357,Training +47333865,467357,Training +47333870,467350,Training +47333873,467350,Training +47333875,467350,Training +47333876,467346,Training +47333882,467346,Training +47333884,467346,Training +47333888,467341,Training +47333890,467341,Training +47333892,467341,Training +47333945,467351,Training +47333946,467351,Training +47333949,467351,Training +47333960,467345,Training +47333963,467345,Training +47333964,467344,Training +47333965,467344,Training +47333973,467337,Training +47333974,467337,Training +47333975,467337,Training +47333988,467304,Training +47333994,467304,Training +47333996,467304,Training +47334000,467297,Training +47334001,467297,Training +47334005,467296,Training +47334046,469837,Training +47334050,469837,Training +47334052,469837,Training +47334055,469830,Training +47334057,469830,Training +47334058,469830,Training +47334066,469825,Training +47334068,469825,Training +47334069,469825,Training +47334070,469825,Training +47334074,469819,Training +47334080,469819,Training +47334130,469835,Training +47334131,469835,Training +47334151,469804,Training +47334153,469804,Training +47334154,469804,Training +47334161,470138,Training +47334163,470138,Training +47334165,470130,Training +47334169,470130,Training +47334194,470111,Training +47334195,470111,Training +47334480,468107,Training +47334482,468107,Training +47334483,468107,Training +47334493,468146,Training +47334498,468146,Training +47334499,468146,Training +47334514,468140,Training +47334515,468140,Training +47334522,468136,Training +47334523,468136,Training +47334525,468136,Training +47334539,468115,Training +47334540,468115,Training +47334543,468115,Training +47334544,468111,Training +47334547,468111,Training +47334548,468111,Training +47334552,468108,Training +47334557,468108,Training +47334558,468108,Training +47334559,468104,Training +47334565,468104,Training +47334566,468104,Training +47334609,481529,Training +47334627,481523,Training +47334628,481523,Training +47334629,481523,Training +47334649,481494,Training +47334652,481494,Training +47334655,481494,Training +47334666,481487,Training +47334669,481487,Training +47334672,481487,Training +47334768,481801,Training +47334770,481801,Training +47334791,481772,Training +47334808,481770,Training +47334810,481770,Training +47429646,470340,Training +47429650,470340,Training +47429663,470335,Training +47429664,470335,Training +47429670,470327,Training +47429671,470327,Training +47429677,470327,Training +47429682,470322,Training +47429688,470315,Training +47429689,470315,Training +47429694,470314,Training +47429717,471376,Training +47429720,471376,Training +47429721,471376,Training +47429723,471376,Training +47429725,471369,Training +47429728,471369,Training +47429729,471369,Training +47429736,471364,Training +47429738,471364,Training +47429739,471364,Training +47429750,471429,Training +47429752,471429,Training +47429753,471429,Training +47429756,471422,Training +47429758,471422,Training +47429762,471422,Training +47429765,471416,Training +47429766,471416,Training +47429771,471411,Training +47429773,471411,Training +47429774,471411,Training +47429786,471389,Training +47429787,471389,Training +47429788,471389,Training +47429841,471431,Training +47429844,471431,Training +47429846,471431,Training +47429851,471424,Training +47429854,471424,Training +47429855,471424,Training +47429860,471418,Training +47429861,471418,Training +47429867,471418,Training +47429870,471410,Training +47429871,471410,Training +47429874,471410,Training +47429888,471367,Training +47429889,471367,Training +47429890,471367,Training +47430056,470808,Training +47430058,470808,Training +47430060,470808,Training +47430065,470807,Training +47430067,470807,Training +47430068,470807,Training +47430074,470883,Training +47430077,470883,Training +47430089,470874,Training +47430092,470872,Training +47430094,470872,Training +47430097,470872,Training +47430100,470867,Training +47430101,470867,Training +47430103,470867,Training +47430107,470819,Training +47430111,470819,Training +47430114,470819,Training +47430127,470814,Training +47430132,470814,Training +47430137,470812,Training +47430140,470812,Training +47430141,470812,Training +47430155,471031,Training +47430189,470997,Training +47430192,470997,Training +47430193,470997,Training +47430262,470985,Training +47430269,470985,Training +47430271,470985,Training +47430280,470980,Training +47430287,470980,Training +47430288,470980,Training +47430294,470978,Training +47430298,470978,Training +47430408,471156,Training +47430410,471156,Training +47430414,471156,Training +47430439,471192,Training +47430441,471192,Training +47430442,471192,Training +47430449,471184,Training +47430452,471184,Training +47430455,471184,Training +47430519,471676,Training +47430525,471673,Training +47430530,471673,Training +47430531,471673,Training +47430536,471670,Training +47430554,471666,Training +47430565,472076,Training +47430566,472076,Training +47430568,472076,Training +47430637,472353,Training +47430640,472353,Training +47430646,472348,Training +47430661,472342,Training +47430679,472316,Training +47430683,472316,Training +47430690,472310,Training +47430692,472310,Training +47430694,472310,Training +47430696,472307,Training +47430702,472307,Training +47430722,472630,Training +47430724,472630,Training +47430726,472630,Training +47430732,472625,Training +47430739,472622,Training +47430740,472622,Training +47430785,472577,Training +47430795,469646,Training +47430800,469646,Training +47430802,469646,Training +47430805,469638,Training +47430808,469638,Training +47430809,469638,Training +47430814,469632,Training +47430818,469632,Training +47430820,469632,Training +47430822,469622,Training +47430824,469622,Training +47430828,469622,Training +47430831,469618,Training +47430832,469618,Training +47430834,469618,Training +47430839,469615,Training +47430840,469615,Training +47430844,469615,Training +47430845,469615,Training +47430854,469606,Training +47430855,469606,Training +47430857,469606,Training +47430858,469606,Training +47430860,469606,Training +47430862,469606,Training +47430894,469749,Training +47430896,469749,Training +47430898,469749,Training +47430900,469744,Training +47430901,469744,Training +47430904,469744,Training +47430924,469831,Training +47430928,469831,Training +47430931,469831,Training +47430934,469828,Training +47430935,469828,Training +47430937,469828,Training +47430939,469824,Training +47430940,469824,Training +47430942,469824,Training +47430944,469821,Training +47430948,469821,Training +47430949,469821,Training +47430950,469821,Training +47430962,469795,Training +47430964,469795,Training +47430966,469795,Training +47431030,470110,Training +47431034,470110,Training +47431071,468296,Training +47431072,468296,Training +47431073,468296,Training +47431076,468296,Training +47431078,468296,Training +47431081,468285,Training +47431082,468285,Training +47431083,468285,Training +47431088,468283,Training +47431090,468283,Training +47431091,468283,Training +47431092,468283,Training +47431103,481527,Training +47431106,481527,Training +47431107,481527,Training +47431153,481485,Training +47431155,481485,Training +47431173,481480,Training +47669851,481678,Training +47669854,481678,Training +47669856,481678,Training +47669874,481513,Training +47669876,481513,Training +47669878,481513,Training +47669880,481508,Training +47669881,481508,Training +47669896,481508,Training +47669901,481488,Training +47669903,481488,Training +47669907,481488,Training +47669909,481477,Training +47669911,481477,Training +47669913,481477,Training +47669917,481670,Training +47669925,481670,Training +47669927,481670,Training +47669934,481664,Training +47669936,481664,Training +47669937,481664,Training +47669939,481659,Training +47669943,481659,Training +47669948,481659,Training +47669968,482120,Training +47669978,482090,Training +47669980,482090,Training +47669984,482090,Training +47670001,482104,Training +47670002,482104,Training +47670003,482104,Training +47670008,482102,Training +47670011,482102,Training +47670023,482231,Training +47670028,482231,Training +47670031,482231,Training +47670038,482230,Training +47670039,482230,Training +47670040,482230,Training +47670053,482224,Training +47670063,482315,Training +47670064,482315,Training +47670066,482315,Training +47670073,482306,Training +47670074,482306,Training +47670075,482306,Training +47670100,482484,Training +47670105,482484,Training +47670109,482484,Training +47670133,482294,Training +47670134,482294,Training +47670136,482294,Training +47670149,482293,Training +47670150,482293,Training +47670171,482489,Training +47670174,482489,Training +47670182,482470,Training +47670183,482470,Training +47670184,482769,Training +47670191,482769,Training +47670195,482761,Training +47670197,482761,Training +47670201,482761,Training +47670208,482755,Training +47670218,482755,Training +47670224,482755,Training +47670232,482763,Training +47670233,482763,Training +47670234,482763,Training +47670242,482860,Training +47670244,482860,Training +47670250,482860,Training +47670269,482980,Training +47670270,482980,Training +47670271,482980,Training +47670288,482873,Training +47670291,482873,Training +47670293,482873,Training +47670295,482858,Training +47670299,482858,Training +47670305,482858,Training +47670311,482857,Training +47670315,482857,Training +47670318,482857,Training +47670323,482992,Training +47670324,482992,Training +47670328,482992,Training +47670334,482984,Training +47670340,482984,Training +47670346,482984,Training +47670350,482977,Training +47670356,482977,Training +47670358,482977,Training +47895179,471687,Training +47895184,471687,Training +47895186,471687,Training +47895187,471677,Training +47895191,471677,Training +47895192,471677,Training +47895196,471672,Training +47895197,471672,Training +47895199,471672,Training +47895201,471664,Training +47895206,471664,Training +47895207,471664,Training +47895221,472065,Training +47895231,472058,Training +47895233,472058,Training +47895249,472050,Training +47895250,472050,Training +47895263,472026,Training +47895285,472201,Training +47895286,472201,Training +47895325,472333,Training +47895446,478025,Training +47895448,478025,Training +47895478,478011,Training +47895480,478011,Training +47895482,478011,Training +47895520,482115,Training +47895622,482323,Training +47895653,482135,Training +47895659,482135,Training +47895661,482134,Training +47895662,482134,Training +47895678,482129,Training +47895680,482129,Training +47895718,482095,Training +47895719,482095,Training +47895720,482095,Training +47895891,482592,Training +47895899,482592,Training +47895900,482591,Training +47895903,482591,Training +47895904,482591,Training +47895909,482587,Training +47895910,482587,Training +47895911,482587,Training +47895978,482609,Training +47895979,482609,Training +47895983,482609,Training +48017884,482866,Training +48018010,482851,Training +48018067,483107,Training +48018068,483107,Training +48018069,483107,Training +48018071,483107,Training +48018079,483089,Training +48018081,483089,Training +48018084,483089,Training +48018087,483087,Training +48018089,483087,Training +48018090,483087,Training +48018149,483067,Training +48018155,483066,Training +48018249,483076,Training +48018250,483076,Training +48018252,483076,Training +48018311,483322,Training +48018312,483322,Training +48018315,483322,Training +48018317,483317,Training +48018318,483317,Training +48018320,483317,Training +48018331,483304,Training +48018334,483304,Training +48018404,483297,Training +48018405,483297,Training +48018584,483590,Training +48018746,483595,Training +48018750,483595,Training +48018751,483595,Training +48018757,484003,Training +48018758,484003,Training +48018760,484003,Training +48018776,483987,Training +48018778,483987,Training +48018780,483987,Training +48018785,483959,Training +48018788,483959,Training +48018789,483959,Training +48018790,483959,Training +48018791,483959,Training +48018800,483954,Training +48018802,483954,Training +48018804,483954,Training +48018808,483952,Training +48018809,483952,Training +48018811,483952,Training +48018815,483950,Training +48018818,483950,Training +48458202,483075,Training +48458203,483075,Training +48458206,483075,Training +48458211,483073,Training +48458218,483073,Training +48458219,483073,Training +48458221,483071,Training +48458225,483071,Training +48458228,483071,Training +48458231,483247,Training +48458235,483247,Training +48458245,483247,Training +48458249,483244,Training +48458251,483244,Training +48458252,483244,Training +48458259,483084,Training +48458260,483084,Training +48458262,483084,Training +48458282,483249,Training +48458283,483249,Training +48458286,483249,Training +48458295,483245,Training +48458297,483245,Training +48458301,483245,Training +48458535,484276,Training +48458541,484276,Training +48458566,484249,Training +48458569,484249,Training +48458572,484249,Training +48458582,484445,Training +48458584,484445,Training +48458587,484445,Training +48458637,484452,Training +48458639,484452,Training +48458640,484452,Training +48458675,484715,Training +48458678,484715,Training +48458688,484715,Training +48458690,484568,Training +48458700,484568,Training +48458703,484563,Training +48458704,484563,Training +48458706,484563,Training +48458716,484562,Training +48458717,484562,Training +48458719,484562,Training +41069021,381658,Validation +41069025,381658,Validation +41069042,381649,Validation +41069043,381649,Validation +41069046,381649,Validation +41069048,381644,Validation +41069050,381644,Validation +41069051,381644,Validation +41142278,384651,Validation +41142280,384651,Validation +41142281,384651,Validation +42444946,421337,Validation +42444949,421337,Validation +42444950,421337,Validation +42444966,421383,Validation +42444968,421383,Validation +42444976,421383,Validation +42445021,421380,Validation +42445022,421380,Validation +42445026,421380,Validation +42445028,421378,Validation +42445029,421378,Validation +42445031,421378,Validation +42445429,421372,Validation +42445441,421372,Validation +42445448,421372,Validation +42445991,422024,Validation +42446038,422015,Validation +42446049,422015,Validation +42446100,422034,Validation +42446103,422034,Validation +42446114,422034,Validation +42446156,422018,Validation +42446163,422014,Validation +42446165,422014,Validation +42446167,422014,Validation +42446517,422391,Validation +42446519,422391,Validation +42446522,422391,Validation +42446527,422386,Validation +42446529,422386,Validation +42446532,422386,Validation +42446533,422382,Validation +42446535,422382,Validation +42446536,422382,Validation +42446540,422381,Validation +42446541,422381,Validation +42897501,422851,Validation +42897504,422851,Validation +42897508,422851,Validation +42897521,422829,Validation +42897526,422829,Validation +42897528,422829,Validation +42897541,422826,Validation +42897542,422826,Validation +42897545,422813,Validation +42897549,422813,Validation +42897550,422813,Validation +42897552,422803,Validation +42897554,422803,Validation +42897559,422803,Validation +42897561,422785,Validation +42897564,422785,Validation +42897566,422785,Validation +42897599,423477,Validation +42897647,423448,Validation +42897651,423448,Validation +42897667,423448,Validation +42897688,423441,Validation +42898811,434650,Validation +42898816,434650,Validation +42898817,434650,Validation +42898818,434641,Validation +42898822,434641,Validation +42898826,434641,Validation +42898849,434659,Validation +42898854,434659,Validation +42898862,434659,Validation +42899461,435385,Validation +42899471,435385,Validation +42899611,435374,Validation +42899612,435374,Validation +42899613,435374,Validation +42899619,435368,Validation +42899620,435368,Validation +44358442,460419,Validation +44358446,460419,Validation +44358448,460419,Validation +44358451,460417,Validation +44358452,460417,Validation +44358455,460417,Validation +45260854,466193,Validation +45260856,466193,Validation +45260857,466193,Validation +45260898,466192,Validation +45260899,466192,Validation +45260900,466192,Validation +45260920,466183,Validation +45260925,466183,Validation +45260928,466183,Validation +45261121,466628,Validation +45261128,466628,Validation +45261133,466803,Validation +45261140,466803,Validation +45261142,466803,Validation +45261143,466801,Validation +45261144,466801,Validation +45261150,466801,Validation +45261179,466802,Validation +45261181,466802,Validation +45261182,466802,Validation +45261185,466801,Validation +45261190,466801,Validation +45261193,466801,Validation +45261575,468079,Validation +45261581,468079,Validation +45261582,468079,Validation +45261588,468073,Validation +45261594,468073,Validation +45261615,467293,Validation +45261619,467293,Validation +45261620,467293,Validation +45261631,468076,Validation +45261632,468076,Validation +45261637,468076,Validation +45662921,468646,Validation +45662924,468646,Validation +45662926,468646,Validation +45662942,468709,Validation +45662943,468709,Validation +45662944,468709,Validation +45662970,468712,Validation +45662975,468712,Validation +45662979,468712,Validation +45662981,468711,Validation +45662983,468711,Validation +45662987,468711,Validation +45663113,469013,Validation +45663114,469013,Validation +45663115,469013,Validation +45663149,469021,Validation +45663150,469021,Validation +45663154,469021,Validation +45663164,469011,Validation +45663165,469011,Validation +45663175,469011,Validation +47115452,470661,Validation +47115460,470661,Validation +47115463,470661,Validation +47115469,470652,Validation +47115473,470652,Validation +47115474,470652,Validation +47204552,471952,Validation +47204554,471952,Validation +47204556,471952,Validation +47204559,471948,Validation +47204563,471948,Validation +47204566,471948,Validation +47204573,471940,Validation +47204575,471940,Validation +47204578,471940,Validation +47204605,471942,Validation +47204607,471942,Validation +47204609,471942,Validation +47331068,470352,Validation +47331069,470352,Validation +47331071,470352,Validation +47331262,468267,Validation +47331265,468267,Validation +47331266,468267,Validation +47331319,468286,Validation +47331322,468286,Validation +47331324,468286,Validation +47331336,469319,Validation +47331337,469319,Validation +47331339,469319,Validation +47331651,466849,Validation +47331653,466849,Validation +47331654,466849,Validation +47331661,466846,Validation +47331662,466846,Validation +47331988,470811,Validation +47331989,470811,Validation +47331990,470811,Validation +47332000,470806,Validation +47332004,470806,Validation +47332005,470806,Validation +47332885,469272,Validation +47332886,469272,Validation +47332890,469272,Validation +47332893,469266,Validation +47332895,469266,Validation +47332899,469266,Validation +47332901,469259,Validation +47332904,469259,Validation +47332905,469259,Validation +47332908,469255,Validation +47332910,469255,Validation +47332911,469255,Validation +47332915,469249,Validation +47332916,469249,Validation +47332918,469249,Validation +47333440,467176,Validation +47333441,467176,Validation +47333443,467176,Validation +47333452,467172,Validation +47333456,467172,Validation +47333898,467314,Validation +47333899,467314,Validation +47333904,467314,Validation +47333916,467311,Validation +47333918,467311,Validation +47333920,467311,Validation +47333923,467306,Validation +47333924,467306,Validation +47333925,467306,Validation +47333927,467305,Validation +47333931,467305,Validation +47333932,467305,Validation +47333934,467301,Validation +47333937,467301,Validation +47333940,467301,Validation +47334105,469784,Validation +47334239,470101,Validation +47334240,470101,Validation +47334241,470101,Validation +47334256,470098,Validation +47429912,471428,Validation +47429922,471425,Validation +47429971,470550,Validation +47429977,470550,Validation +47429992,470543,Validation +47429995,470543,Validation +47429998,470541,Validation +47430001,470541,Validation +47430002,470541,Validation +47430003,470537,Validation +47430005,470537,Validation +47430023,470516,Validation +47430024,470516,Validation +47430026,470516,Validation +47430033,470512,Validation +47430034,470512,Validation +47430036,470512,Validation +47430038,470508,Validation +47430043,470508,Validation +47430045,470508,Validation +47430047,470507,Validation +47430048,470507,Validation +47430051,470507,Validation +47895341,472297,Validation +47895348,472297,Validation +47895350,472297,Validation +47895355,472487,Validation +47895364,472483,Validation +47895365,472483,Validation +47895552,482083,Validation +47895554,482083,Validation +47895556,482083,Validation +47895779,482296,Validation +47895783,482296,Validation +48018367,483313,Validation +48018368,483313,Validation +48018372,483313,Validation +48018379,483312,Validation +48018559,483621,Validation +48018560,483621,Validation +48018562,483621,Validation +48018566,483620,Validation +48018571,483620,Validation +48018572,483620,Validation +48018730,483632,Validation +48018732,483632,Validation +48018733,483632,Validation +48018737,483618,Validation +48018739,483618,Validation +48018741,483618,Validation +48458481,483953,Validation +48458484,483953,Validation +48458489,483953,Validation +48458647,484543,Validation +48458650,484543,Validation +48458656,484540,Validation +48458657,484540,Validation +48458660,484534,Validation +48458663,484534,Validation +48458667,484534,Validation diff --git a/data/ARKitScenes/depth_upsampling/upsampling_train_val_splits_400.csv b/data/ARKitScenes/depth_upsampling/upsampling_train_val_splits_400.csv new file mode 100644 index 0000000..6245da4 --- /dev/null +++ b/data/ARKitScenes/depth_upsampling/upsampling_train_val_splits_400.csv @@ -0,0 +1,687 @@ +video_id,visit_id,fold +41048190,381531,Training +41048223,381654,Training +41048225,381654,Training +41048229,381654,Training +41048247,381652,Training +41048249,381652,Training +41048251,381652,Training +41048262,381650,Training +41048264,381650,Training +41048265,381650,Training +42444474,421069,Training +42444477,421069,Training +42444490,421069,Training +42444499,421065,Training +42444501,421065,Training +42444503,421065,Training +42444511,421063,Training +42444512,421063,Training +42444513,421063,Training +42444514,421061,Training +42444515,421061,Training +42444517,421061,Training +42444574,421062,Training +42444579,421062,Training +42444588,421062,Training +42444692,421016,Training +42444695,421016,Training +42444696,421016,Training +42444703,421013,Training +42444706,421013,Training +42444708,421013,Training +42444709,421010,Training +42444711,421010,Training +42444712,421010,Training +42444716,421005,Training +42444719,421005,Training +42444721,421005,Training +42444733,421267,Training +42444735,421267,Training +42444738,421267,Training +42444748,421259,Training +42444750,421259,Training +42444751,421259,Training +42444754,421254,Training +42444755,421254,Training +42444758,421254,Training +42444762,421252,Training +42444767,421252,Training +42444768,421252,Training +42444787,421015,Training +42444789,421015,Training +42444791,421012,Training +42444793,421012,Training +42444794,421012,Training +42444821,421255,Training +42444822,421255,Training +42444826,421255,Training +42444858,421009,Training +42444859,421009,Training +42444860,421009,Training +42444866,421006,Training +42444869,421006,Training +42444872,421006,Training +42444873,421002,Training +42444875,421002,Training +42444876,421002,Training +42444883,421264,Training +42444885,421264,Training +42444887,421264,Training +42444891,421260,Training +42444892,421260,Training +42444896,421260,Training +42444904,421256,Training +42444907,421256,Training +42444913,421397,Training +42444916,421397,Training +42444917,421397,Training +42444923,421393,Training +42444924,421393,Training +42444928,421392,Training +42444932,421392,Training +42444933,421392,Training +42445044,421659,Training +42445047,421659,Training +42445057,421652,Training +42445060,421652,Training +42445063,421652,Training +42445078,421647,Training +42445079,421647,Training +42445081,421647,Training +42445100,422200,Training +42445103,422200,Training +42445132,420683,Training +42445135,420683,Training +42445137,420683,Training +42445168,420693,Training +42445169,420693,Training +42445173,420693,Training +42445198,420673,Training +42445205,420673,Training +42445211,420673,Training +42445428,421391,Training +42445444,421386,Training +42445445,421386,Training +42445451,421386,Training +42445476,421379,Training +42445478,421379,Training +42445479,421379,Training +42445494,421853,Training +42445498,421853,Training +42445502,421853,Training +42445584,421644,Training +42445587,421628,Training +42445592,421628,Training +42445597,421602,Training +42445599,421602,Training +42445611,422214,Training +42445612,421667,Training +42445615,422214,Training +42445619,421667,Training +42445633,421657,Training +42445639,421657,Training +42445642,421657,Training +42445670,422163,Training +42445676,422163,Training +42445680,422155,Training +42445684,422155,Training +42445689,422155,Training +42445691,421655,Training +42445692,422148,Training +42445695,421655,Training +42445697,422148,Training +42445698,421655,Training +42445707,422134,Training +42445716,422134,Training +42445718,422217,Training +42445720,422217,Training +42445721,422134,Training +42445723,422217,Training +42445728,421616,Training +42445729,421616,Training +42445736,421616,Training +42445745,421593,Training +42445758,421593,Training +42445766,421593,Training +42445769,421658,Training +42445770,421593,Training +42445771,421658,Training +42445775,422203,Training +42445781,422203,Training +42445782,421658,Training +42445783,422203,Training +42445784,421654,Training +42445785,422203,Training +42445788,421654,Training +42445790,422195,Training +42445794,421654,Training +42445796,422195,Training +42445799,422195,Training +42445802,422182,Training +42445804,422182,Training +42445806,422182,Training +42445834,422149,Training +42445862,422023,Training +42445864,422023,Training +42445865,422023,Training +42445869,422017,Training +42445872,422017,Training +42445873,422017,Training +42445877,422013,Training +42445881,422013,Training +42445882,422013,Training +42445884,422009,Training +42445888,422009,Training +42445889,422009,Training +42445891,422009,Training +42445894,422589,Training +42445902,422149,Training +42445903,422149,Training +42445913,422589,Training +42445916,422569,Training +42445922,422551,Training +42445924,422551,Training +42445927,422551,Training +42445931,422539,Training +42445938,422539,Training +42445970,422011,Training +42445987,422011,Training +42445999,422010,Training +42446008,422010,Training +42446017,422007,Training +42446031,422006,Training +42446036,422006,Training +42446039,422006,Training +42446048,422535,Training +42446050,422535,Training +42446057,422535,Training +42446061,422521,Training +42446068,422521,Training +42446080,421948,Training +42446159,422546,Training +42446164,422546,Training +42446445,422543,Training +42446450,422543,Training +42446467,422523,Training +42446468,422523,Training +42446478,422523,Training +42446492,422518,Training +42446493,422518,Training +42446495,422516,Training +42446497,422516,Training +42446558,422399,Training +42446561,422399,Training +42446574,422399,Training +42446576,422356,Training +42446579,422356,Training +42446605,422323,Training +42446607,422384,Training +42446608,422323,Training +42447199,422384,Training +42447202,423070,Training +42447203,422384,Training +42447205,423070,Training +42447210,423070,Training +42447214,422380,Training +42447221,423511,Training +42447226,423511,Training +42447230,422380,Training +42447233,422380,Training +42447275,422383,Training +42447287,422383,Training +42447294,422383,Training +42447307,422378,Training +42447308,422378,Training +42447310,422378,Training +42447320,422377,Training +42447329,422377,Training +42447336,422377,Training +42897405,423474,Training +42897409,423474,Training +42897410,423474,Training +42897418,423461,Training +42897419,423461,Training +42897421,423461,Training +42897422,423452,Training +42897426,423452,Training +42897434,423452,Training +42897436,423438,Training +42897439,423438,Training +42897442,423438,Training +42897452,422862,Training +42897455,422862,Training +42897478,423442,Training +42897479,423442,Training +42897480,422855,Training +42897482,423442,Training +42897490,422855,Training +42897509,422847,Training +42897512,422847,Training +42897523,422847,Training +42897547,422842,Training +42897560,422842,Training +42897598,422379,Training +42897600,422379,Training +42897605,422376,Training +42897607,422376,Training +42897612,422376,Training +42897631,422354,Training +42897633,422354,Training +42897634,422354,Training +42897655,422849,Training +42897675,422838,Training +42897677,422838,Training +42897681,422838,Training +42897709,423337,Training +42897712,423337,Training +42897713,423337,Training +42897720,423325,Training +42897722,423325,Training +42897732,423312,Training +42897735,423312,Training +42897736,423312,Training +42897743,423306,Training +42897744,423306,Training +42897745,423306,Training +42897755,423747,Training +42897756,423747,Training +42897771,423335,Training +42897776,423307,Training +42897777,423307,Training +42897780,423307,Training +42897783,423306,Training +42897784,423306,Training +42897785,423306,Training +42897815,423738,Training +42897818,423738,Training +42897848,423320,Training +42897851,423320,Training +42897857,423320,Training +42897863,423315,Training +42897868,423315,Training +42897877,423310,Training +42897892,423296,Training +42897898,423296,Training +42897924,423801,Training +42897925,423801,Training +42897928,423801,Training +42897930,423792,Training +42897931,423792,Training +42897934,423792,Training +42897939,423782,Training +42897945,423782,Training +42897948,423782,Training +42897955,423324,Training +42897960,423324,Training +42897967,423324,Training +42898006,423311,Training +42898007,423311,Training +42898052,423791,Training +42898057,423777,Training +42898059,423777,Training +42898061,423777,Training +42898065,423770,Training +42898067,423770,Training +42898068,423770,Training +42898070,423614,Training +42898071,423614,Training +42898075,423614,Training +42898083,423611,Training +42898087,423611,Training +42898089,423611,Training +42898094,423605,Training +42898097,423605,Training +42898098,423605,Training +42898100,423613,Training +42898109,423613,Training +42898112,423613,Training +42898123,423980,Training +42898125,423980,Training +42898132,423978,Training +42898135,423978,Training +42898141,423978,Training +42898156,426265,Training +42898160,426265,Training +42898162,426265,Training +42898163,426259,Training +42898169,426259,Training +42898182,426259,Training +42898189,426247,Training +42898191,426247,Training +42898195,426247,Training +42898221,423989,Training +42898230,423989,Training +42898234,423989,Training +42898236,423974,Training +42898247,423974,Training +42898248,423974,Training +42898332,423966,Training +42898334,423966,Training +42898337,423966,Training +42898340,423957,Training +42898342,423957,Training +42898343,423957,Training +42898345,423953,Training +42898348,423953,Training +42898391,423964,Training +42898392,423964,Training +42898393,423964,Training +42898405,423956,Training +42898407,423956,Training +42898408,423956,Training +42898447,426262,Training +42898448,426262,Training +42898449,426262,Training +42898454,426166,Training +42898458,426166,Training +42898461,426166,Training +42898470,426153,Training +42898477,426153,Training +42898500,434689,Training +42898501,434689,Training +42898502,434689,Training +42898510,434687,Training +42898511,434687,Training +42898526,426168,Training +42898551,426156,Training +42898555,426156,Training +42898558,426156,Training +42898560,434700,Training +42898571,434700,Training +42898577,434700,Training +42898586,434695,Training +42898587,434695,Training +42898596,434695,Training +42898738,426154,Training +42898745,426154,Training +42898750,426150,Training +42898751,426150,Training +42898752,426150,Training +41069021,381658,Validation +41069025,381658,Validation +41069042,381649,Validation +41069043,381649,Validation +41069046,381649,Validation +41069048,381644,Validation +41069050,381644,Validation +41069051,381644,Validation +41142278,384651,Validation +41142280,384651,Validation +41142281,384651,Validation +42444946,421337,Validation +42444949,421337,Validation +42444950,421337,Validation +42444966,421383,Validation +42444968,421383,Validation +42444976,421383,Validation +42445021,421380,Validation +42445022,421380,Validation +42445026,421380,Validation +42445028,421378,Validation +42445029,421378,Validation +42445031,421378,Validation +42445429,421372,Validation +42445441,421372,Validation +42445448,421372,Validation +42445991,422024,Validation +42446038,422015,Validation +42446049,422015,Validation +42446100,422034,Validation +42446103,422034,Validation +42446114,422034,Validation +42446156,422018,Validation +42446163,422014,Validation +42446165,422014,Validation +42446167,422014,Validation +42446517,422391,Validation +42446519,422391,Validation +42446522,422391,Validation +42446527,422386,Validation +42446529,422386,Validation +42446532,422386,Validation +42446533,422382,Validation +42446535,422382,Validation +42446536,422382,Validation +42446540,422381,Validation +42446541,422381,Validation +42897501,422851,Validation +42897504,422851,Validation +42897508,422851,Validation +42897521,422829,Validation +42897526,422829,Validation +42897528,422829,Validation +42897541,422826,Validation +42897542,422826,Validation +42897545,422813,Validation +42897549,422813,Validation +42897550,422813,Validation +42897552,422803,Validation +42897554,422803,Validation +42897559,422803,Validation +42897561,422785,Validation +42897564,422785,Validation +42897566,422785,Validation +42897599,423477,Validation +42897647,423448,Validation +42897651,423448,Validation +42897667,423448,Validation +42897688,423441,Validation +42898811,434650,Validation +42898816,434650,Validation +42898817,434650,Validation +42898818,434641,Validation +42898822,434641,Validation +42898826,434641,Validation +42898849,434659,Validation +42898854,434659,Validation +42898862,434659,Validation +42899461,435385,Validation +42899471,435385,Validation +42899611,435374,Validation +42899612,435374,Validation +42899613,435374,Validation +42899619,435368,Validation +42899620,435368,Validation +44358442,460419,Validation +44358446,460419,Validation +44358448,460419,Validation +44358451,460417,Validation +44358452,460417,Validation +44358455,460417,Validation +45260854,466193,Validation +45260856,466193,Validation +45260857,466193,Validation +45260898,466192,Validation +45260899,466192,Validation +45260900,466192,Validation +45260920,466183,Validation +45260925,466183,Validation +45260928,466183,Validation +45261121,466628,Validation +45261128,466628,Validation +45261133,466803,Validation +45261140,466803,Validation +45261142,466803,Validation +45261143,466801,Validation +45261144,466801,Validation +45261150,466801,Validation +45261179,466802,Validation +45261181,466802,Validation +45261182,466802,Validation +45261185,466801,Validation +45261190,466801,Validation +45261193,466801,Validation +45261575,468079,Validation +45261581,468079,Validation +45261582,468079,Validation +45261588,468073,Validation +45261594,468073,Validation +45261615,467293,Validation +45261619,467293,Validation +45261620,467293,Validation +45261631,468076,Validation +45261632,468076,Validation +45261637,468076,Validation +45662921,468646,Validation +45662924,468646,Validation +45662926,468646,Validation +45662942,468709,Validation +45662943,468709,Validation +45662944,468709,Validation +45662970,468712,Validation +45662975,468712,Validation +45662979,468712,Validation +45662981,468711,Validation +45662983,468711,Validation +45662987,468711,Validation +45663113,469013,Validation +45663114,469013,Validation +45663115,469013,Validation +45663149,469021,Validation +45663150,469021,Validation +45663154,469021,Validation +45663164,469011,Validation +45663165,469011,Validation +45663175,469011,Validation +47115452,470661,Validation +47115460,470661,Validation +47115463,470661,Validation +47115469,470652,Validation +47115473,470652,Validation +47115474,470652,Validation +47204552,471952,Validation +47204554,471952,Validation +47204556,471952,Validation +47204559,471948,Validation +47204563,471948,Validation +47204566,471948,Validation +47204573,471940,Validation +47204575,471940,Validation +47204578,471940,Validation +47204605,471942,Validation +47204607,471942,Validation +47204609,471942,Validation +47331068,470352,Validation +47331069,470352,Validation +47331071,470352,Validation +47331262,468267,Validation +47331265,468267,Validation +47331266,468267,Validation +47331319,468286,Validation +47331322,468286,Validation +47331324,468286,Validation +47331336,469319,Validation +47331337,469319,Validation +47331339,469319,Validation +47331651,466849,Validation +47331653,466849,Validation +47331654,466849,Validation +47331661,466846,Validation +47331662,466846,Validation +47331988,470811,Validation +47331989,470811,Validation +47331990,470811,Validation +47332000,470806,Validation +47332004,470806,Validation +47332005,470806,Validation +47332885,469272,Validation +47332886,469272,Validation +47332890,469272,Validation +47332893,469266,Validation +47332895,469266,Validation +47332899,469266,Validation +47332901,469259,Validation +47332904,469259,Validation +47332905,469259,Validation +47332908,469255,Validation +47332910,469255,Validation +47332911,469255,Validation +47332915,469249,Validation +47332916,469249,Validation +47332918,469249,Validation +47333440,467176,Validation +47333441,467176,Validation +47333443,467176,Validation +47333452,467172,Validation +47333456,467172,Validation +47333898,467314,Validation +47333899,467314,Validation +47333904,467314,Validation +47333916,467311,Validation +47333918,467311,Validation +47333920,467311,Validation +47333923,467306,Validation +47333924,467306,Validation +47333925,467306,Validation +47333927,467305,Validation +47333931,467305,Validation +47333932,467305,Validation +47333934,467301,Validation +47333937,467301,Validation +47333940,467301,Validation +47334105,469784,Validation +47334239,470101,Validation +47334240,470101,Validation +47334241,470101,Validation +47334256,470098,Validation +47429912,471428,Validation +47429922,471425,Validation +47429971,470550,Validation +47429977,470550,Validation +47429992,470543,Validation +47429995,470543,Validation +47429998,470541,Validation +47430001,470541,Validation +47430002,470541,Validation +47430003,470537,Validation +47430005,470537,Validation +47430023,470516,Validation +47430024,470516,Validation +47430026,470516,Validation +47430033,470512,Validation +47430034,470512,Validation +47430036,470512,Validation +47430038,470508,Validation +47430043,470508,Validation +47430045,470508,Validation +47430047,470507,Validation +47430048,470507,Validation +47430051,470507,Validation +47895341,472297,Validation +47895348,472297,Validation +47895350,472297,Validation +47895355,472487,Validation +47895364,472483,Validation +47895365,472483,Validation +47895552,482083,Validation +47895554,482083,Validation +47895556,482083,Validation +47895779,482296,Validation +47895783,482296,Validation +48018367,483313,Validation +48018368,483313,Validation +48018372,483313,Validation +48018379,483312,Validation +48018559,483621,Validation +48018560,483621,Validation +48018562,483621,Validation +48018566,483620,Validation +48018571,483620,Validation +48018572,483620,Validation +48018730,483632,Validation +48018732,483632,Validation +48018733,483632,Validation +48018737,483618,Validation +48018739,483618,Validation +48018741,483618,Validation +48458481,483953,Validation +48458484,483953,Validation +48458489,483953,Validation +48458647,484543,Validation +48458650,484543,Validation +48458656,484540,Validation +48458657,484540,Validation +48458660,484534,Validation +48458663,484534,Validation +48458667,484534,Validation diff --git a/data/ARKitScenes/download_data.py b/data/ARKitScenes/download_data.py new file mode 100644 index 0000000..3d8d5d1 --- /dev/null +++ b/data/ARKitScenes/download_data.py @@ -0,0 +1,279 @@ +import argparse +import subprocess +import pandas as pd +import math +import os + +ARkitscense_url = 'https://docs-assets.developer.apple.com/ml-research/datasets/arkitscenes/v1' +TRAINING = 'Training' +VALIDATION = 'Validation' +HIGRES_DEPTH_ASSET_NAME = 'highres_depth' +POINT_CLOUDS_FOLDER = 'laser_scanner_point_clouds' + +default_raw_dataset_assets = ['mov', 'annotation', 'mesh', 'confidence', 'highres_depth', 'lowres_depth', + 'lowres_wide.traj', 'lowres_wide', 'lowres_wide_intrinsics', 'ultrawide', + 'ultrawide_intrinsics', 'vga_wide', 'vga_wide_intrinsics'] + +missing_3dod_assets_video_ids = ['47334522', '47334523', '42897421', '45261582', '47333152', '47333155', + '48458535', '48018733', '47429677', '48458541', '42897848', '47895482', + '47333960', '47430089', '42899148', '42897612', '42899153', '42446164', + '48018149', '47332198', '47334515', '45663223', '45663226', '45663227'] + + +def raw_files(video_id, assets, metadata): + file_names = [] + for asset in assets: + if HIGRES_DEPTH_ASSET_NAME == asset: + in_upsampling = metadata.loc[metadata['video_id'] == float(video_id), ['is_in_upsampling']].iat[0, 0] + if not in_upsampling: + print(f"Skipping asset {asset} for video_id {video_id} - Video not in upsampling dataset") + continue # highres_depth asset only available for video ids from upsampling dataset + + if asset in ['confidence', 'highres_depth', 'lowres_depth', 'lowres_wide', 'lowres_wide_intrinsics', + 'ultrawide', 'ultrawide_intrinsics', 'vga_wide', 'vga_wide_intrinsics']: + file_names.append(asset + '.zip') + elif asset == 'mov': + file_names.append(f'{video_id}.mov') + elif asset == 'mesh': + if video_id not in missing_3dod_assets_video_ids: + file_names.append(f'{video_id}_3dod_mesh.ply') + elif asset == 'annotation': + if video_id not in missing_3dod_assets_video_ids: + file_names.append(f'{video_id}_3dod_annotation.json') + elif asset == 'lowres_wide.traj': + if video_id not in missing_3dod_assets_video_ids: + file_names.append('lowres_wide.traj') + else: + raise Exception(f'No asset = {asset} in raw dataset') + return file_names + + +def download_file(url, file_name, dst): + os.makedirs(dst, exist_ok=True) + filepath = os.path.join(dst, file_name) + + if not os.path.isfile(filepath): + command = f"curl {url} -o {file_name}.tmp --fail" + print(f"Downloading file {filepath}") + try: + subprocess.check_call(command, shell=True, cwd=dst) + except Exception as error: + print(f'Error downloading {url}, error: {error}') + return False + os.rename(filepath+".tmp", filepath) + else: + print(f'WARNING: skipping download of existing file: {filepath}') + return True + + +def unzip_file(file_name, dst, keep_zip=True): + filepath = os.path.join(dst, file_name) + print(f"Unzipping zip file {filepath}") + command = f"unzip -oq {filepath} -d {dst}" + try: + subprocess.check_call(command, shell=True) + except Exception as error: + print(f'Error unzipping {filepath}, error: {error}') + return False + if not keep_zip: + os.remove(filepath) + return True + + +def download_laser_scanner_point_clouds_for_video(video_id, metadata, download_dir): + video_metadata = metadata.loc[metadata['video_id'] == float(video_id)] + visit_id = video_metadata['visit_id'].iat[0] + has_laser_scanner_point_clouds = video_metadata['has_laser_scanner_point_clouds'].iat[0] + + if not has_laser_scanner_point_clouds: + print(f"Warning: Laser scanner point clouds for video {video_id} are not available") + return + + if math.isnan(visit_id) or not visit_id.is_integer(): + print(f"Warning: Downloading laser scanner point clouds for video {video_id} failed - Bad visit id {visit_id}") + return + + visit_id = int(visit_id) # Expecting an 8 digit integer + laser_scanner_point_clouds_ids = laser_scanner_point_clouds_for_visit_id(visit_id, download_dir) + + for point_cloud_id in laser_scanner_point_clouds_ids: + download_laser_scanner_point_clouds(point_cloud_id, visit_id, download_dir) + + +def laser_scanner_point_clouds_for_visit_id(visit_id, download_dir): + point_cloud_to_visit_id_mapping_filename = "laser_scanner_point_clouds_mapping.csv" + if not os.path.exists(point_cloud_to_visit_id_mapping_filename): + point_cloud_to_visit_id_mapping_url = \ + f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{point_cloud_to_visit_id_mapping_filename}" + if not download_file(point_cloud_to_visit_id_mapping_url, + point_cloud_to_visit_id_mapping_filename, + download_dir): + print( + f"Error downloading point cloud for visit_id {visit_id} at location " + f"{point_cloud_to_visit_id_mapping_url}") + return [] + + point_cloud_to_visit_id_mapping_filepath = os.path.join(download_dir, point_cloud_to_visit_id_mapping_filename) + point_cloud_to_visit_id_mapping = pd.read_csv(point_cloud_to_visit_id_mapping_filepath) + point_cloud_ids = point_cloud_to_visit_id_mapping.loc[ + point_cloud_to_visit_id_mapping['visit_id'] == visit_id, ["laser_scanner_point_clouds_id"] + ] + point_cloud_ids_list = [scan_id[0] for scan_id in point_cloud_ids.values] + + return point_cloud_ids_list + + +def download_laser_scanner_point_clouds(laser_scanner_point_cloud_id, visit_id, download_dir): + laser_scanner_point_clouds_folder_path = os.path.join(download_dir, POINT_CLOUDS_FOLDER, str(visit_id)) + os.makedirs(laser_scanner_point_clouds_folder_path, exist_ok=True) + + for extension in [".ply", "_pose.txt"]: + filename = f"{laser_scanner_point_cloud_id}{extension}" + filepath = os.path.join(laser_scanner_point_clouds_folder_path, filename) + if os.path.exists(filepath): + return + file_url = f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{visit_id}/{filename}" + download_file(file_url, filename, laser_scanner_point_clouds_folder_path) + + +def get_metadata(dataset, download_dir): + filename = "metadata.csv" + url = f"{ARkitscense_url}/threedod/{filename}" if '3dod' == dataset else f"{ARkitscense_url}/{dataset}/{filename}" + dst_folder = os.path.join(download_dir, dataset) + dst_file = os.path.join(dst_folder, filename) + + if not download_file(url, filename, dst_folder): + return + + metadata = pd.read_csv(dst_file) + return metadata + + +def download_data(dataset, + video_ids, + dataset_splits, + download_dir, + keep_zip, + raw_dataset_assets, + should_download_laser_scanner_point_cloud, + ): + metadata = get_metadata(dataset, download_dir) + if None is metadata: + print(f"Error retrieving metadata for dataset {dataset}") + return + + download_dir = os.path.abspath(download_dir) + for video_id in sorted(set(video_ids)): + split = dataset_splits[video_ids.index(video_id)] + dst_dir = os.path.join(download_dir, dataset, split) + if dataset == 'raw': + url_prefix = "" + file_names = [] + if not raw_dataset_assets: + print(f"Warning: No raw assets given for video id {video_id}") + else: + dst_dir = os.path.join(dst_dir, str(video_id)) + url_prefix = f"{ARkitscense_url}/raw/{split}/{video_id}" + "/{}" + file_names = raw_files(video_id, raw_dataset_assets, metadata) + elif dataset == '3dod': + url_prefix = f"{ARkitscense_url}/threedod/{split}" + "/{}" + file_names = [f"{video_id}.zip", ] + elif dataset == 'upsampling': + url_prefix = f"{ARkitscense_url}/upsampling/{split}" + "/{}" + file_names = [f"{video_id}.zip", ] + else: + raise Exception(f'No such dataset = {dataset}') + + if should_download_laser_scanner_point_cloud and dataset == 'raw': + # Point clouds only available for the raw dataset + download_laser_scanner_point_clouds_for_video(video_id, metadata, download_dir) + + for file_name in file_names: + dst_path = os.path.join(dst_dir, file_name) + url = url_prefix.format(file_name) + + if not file_name.endswith('.zip') or not os.path.isdir(dst_path[:-len('.zip')]): + download_file(url, dst_path, dst_dir) + else: + print(f'WARNING: skipping download of existing zip file: {dst_path}') + if file_name.endswith('.zip') and os.path.isfile(dst_path): + unzip_file(file_name, dst_dir, keep_zip) + + if dataset == 'upsampling' and VALIDATION in dataset_splits: + val_attributes_file = "val_attributes.csv" + url = f"{ARkitscense_url}/upsampling/{VALIDATION}/{val_attributes_file}" + dst_file = os.path.join(download_dir, dataset, VALIDATION) + download_file(url, val_attributes_file, dst_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "dataset", + choices=['3dod', 'upsampling', 'raw'] + ) + + parser.add_argument( + "--split", + choices=["Training", "Validation"], + ) + + parser.add_argument( + "--video_id", + nargs='*' + ) + + parser.add_argument( + "--video_id_csv", + ) + + parser.add_argument( + "--download_dir", + default="data", + ) + + parser.add_argument( + "--keep_zip", + action='store_true' + ) + + parser.add_argument( + "--download_laser_scanner_point_cloud", + action='store_true' + ) + + parser.add_argument( + "--raw_dataset_assets", + nargs='+', + choices=default_raw_dataset_assets + ) + + args = parser.parse_args() + assert args.video_id is not None or args.video_id_csv is not None, \ + 'video_id or video_id_csv must be specified' + assert args.video_id is None or args.video_id_csv is None, \ + 'only video_id or video_id_csv must be specified' + assert args.video_id is None or args.split is not None, \ + 'given video_id the split argument must be specified' + + if args.video_id is not None: + video_ids_ = args.video_id + splits_ = splits = [args.split, ] * len(video_ids_) + elif args.video_id_csv is not None: + df = pd.read_csv(args.video_id_csv) + if args.split is not None: + df = df[df["fold"] == args.split] + video_ids_ = df["video_id"].to_list() + video_ids_ = list(map(str, video_ids_)) # Expecting video id to be a string + splits_ = df["fold"].to_list() + else: + raise Exception('No video ids specified') + + download_data(args.dataset, + video_ids_, + splits_, + args.download_dir, + args.keep_zip, + args.raw_dataset_assets, + args.download_laser_scanner_point_cloud) diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/dataset.py b/dataset/dataset.py new file mode 100644 index 0000000..4f9898d --- /dev/null +++ b/dataset/dataset.py @@ -0,0 +1,157 @@ +import cv2 +import numpy as np +import os +import json +from . import dataset_keys + +from data.ARKitScenes.depth_upsampling.dataset import ARKitScenesDataset +from data.ARKitScenes.depth_upsampling.data_utils import image_hwc_to_chw, expand_channel_dim +from promptda.utils.io_wrapper import ensure_multiple_of + +MILLIMETER_TO_METER = 1000 +WIDE = 'wide' +PATCH_SIZE = 14 # DINOv2 patch size +TARGET_H, TARGET_W = 756, 1008 + + +class MyARKitScenesDataset(ARKitScenesDataset): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # ── Image loading ───────────────────────────────────────────────────── + + def load_image(self, path, shape, is_depth, sky_direction, target_hw=None): + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + img = ARKitScenesDataset.rotate_image(img, sky_direction) + + if is_depth: + img = np.asarray(img / MILLIMETER_TO_METER, np.float32) + if target_hw is not None: + # INTER_NEAREST preserves real depth values — never interpolate depth + img = cv2.resize(img, (target_hw[1], target_hw[0]), + interpolation=cv2.INTER_NEAREST) + img = expand_channel_dim(img) + else: + img = img / 255.0 + h, w = img.shape[:2] + if (h, w) != (TARGET_H, TARGET_W): + img = cv2.resize(img, (TARGET_W, TARGET_H), + interpolation=cv2.INTER_AREA) + + img = image_hwc_to_chw(np.asarray(img, np.float32)) + return img + + # ── Bounding box loading ────────────────────────────────────────────── + + @staticmethod + def load_bounding_box(box_path: str) -> np.ndarray: + """ + Load YOLOv8 boxes from JSON. + Returns: float32 array (N, 4) in xyxy, image-pixel coordinates. + """ + if not os.path.exists(box_path): + return np.zeros((0, 4), dtype=np.float32) + + with open(box_path, "r") as f: + data = json.load(f) + + boxes = data.get("boxes_xyxy_feature", []) + if not boxes: + return np.zeros((0, 4), dtype=np.float32) + + arr = np.array(boxes, dtype=np.float32) + assert arr.ndim == 2 and arr.shape[1] == 4, \ + f"Expected (N,4) boxes, got shape {arr.shape}" + return arr + + @staticmethod + def scale_boxes_to_feature_space( + boxes: np.ndarray, + img_h: int, + img_w: int, + patch_size: int = PATCH_SIZE, + ) -> np.ndarray: + """ + Scale boxes từ image pixel space → DINOv2 feature map space. + + image: (img_h, img_w) + feature map: (img_h // patch_size, img_w // patch_size) + """ + if boxes.shape[0] == 0: + return boxes + + feat_h = img_h // patch_size + feat_w = img_w // patch_size + scale_x = feat_w / img_w + scale_y = feat_h / img_h + + scaled = boxes.copy() + scaled[:, [0, 2]] *= scale_x # x1, x2 + scaled[:, [1, 3]] *= scale_y # y1, y2 + return scaled + + # ── __getitem__ ─────────────────────────────────────────────────────── + + def __getitem__(self, index: int): + video_id, sample_id, direction = self.samples[index] + sample = {dataset_keys.IDENTIFIER: str(sample_id)} + + rgb_file = os.path.join(self.dataset_folder, video_id, WIDE, sample_id) + depth_file = os.path.join(self.dataset_folder, video_id, 'highres_depth', sample_id) + apple_file = os.path.join(self.dataset_folder, video_id, 'lowres_depth', sample_id) + box_file = os.path.join( + self.dataset_folder, video_id, 'boxes', + sample_id.replace('.png', '.json') + ) + + TARGET_HW = (756, 1008) + + # Load RGB first — fixed size + color_img = self.load_image(rgb_file, self.high_res, False, direction) + _, img_h, img_w = color_img.shape # will be (3, 756, 1008) + + # Load depths resized to match RGB exactly + sample[dataset_keys.COLOR_IMG] = color_img + sample[dataset_keys.HIGH_RES_DEPTH_IMG] = self.load_image( + depth_file, self.high_res, True, direction, target_hw=TARGET_HW + ) + sample[dataset_keys.LOW_RES_DEPTH_IMG] = self.load_image( + apple_file, self.low_res, True, direction, target_hw=TARGET_HW + ) + + # Boxes: derived from fixed RGB size, no change needed + boxes_px = self.load_bounding_box(box_file) + boxes_feat = self.scale_boxes_to_feature_space(boxes_px, img_h, img_w) + sample[dataset_keys.BOUNDING_BOX] = boxes_feat + sample[dataset_keys.BOUNDING_BOX_IMAGE] = boxes_px + + if self.transform is not None: + sample[dataset_keys.COLOR_IMG] = self.transform(sample[dataset_keys.COLOR_IMG]) + + return sample + +def collate_fn(batch): + import torch + keys_to_stack = [ + dataset_keys.COLOR_IMG, + dataset_keys.HIGH_RES_DEPTH_IMG, + dataset_keys.LOW_RES_DEPTH_IMG, + ] + result = {} + + for k in keys_to_stack: + result[k] = torch.from_numpy( + np.stack([s[k] for s in batch], axis=0) + ).float() + + # Boxes: list of Tensor (N_i, 4) — N_i khác nhau + result[dataset_keys.BOUNDING_BOX] = [ + torch.from_numpy(s[dataset_keys.BOUNDING_BOX]).float() + for s in batch + ] + + result[dataset_keys.BOUNDING_BOX_IMAGE] = [torch.from_numpy(s[dataset_keys.BOUNDING_BOX_IMAGE]).float() for s in batch] + + result[dataset_keys.IDENTIFIER] = [s[dataset_keys.IDENTIFIER] for s in batch] + return result \ No newline at end of file diff --git a/dataset/dataset_keys.py b/dataset/dataset_keys.py new file mode 100644 index 0000000..7740892 --- /dev/null +++ b/dataset/dataset_keys.py @@ -0,0 +1,8 @@ +IDENTIFIER = 'identifier' +COLOR_IMG = 'color_img' +HIGH_RES_DEPTH_IMG = 'high_res_depth_img' +LOW_RES_DEPTH_IMG = 'low_res_depth_img' +PREDICTION_DEPTH_IMG = 'prediction_img' +VALID_MASK_IMG = 'valid_mask_img' +BOUNDING_BOX = 'boxes' +BOUNDING_BOX_IMAGE = 'boxes_image' \ No newline at end of file diff --git a/dataset/utils/__init__.py b/dataset/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/inference_baseline.py b/inference_baseline.py new file mode 100644 index 0000000..f617915 --- /dev/null +++ b/inference_baseline.py @@ -0,0 +1,195 @@ +""" +Baseline inference script for PromptDA. + +Usage: + python inference_baseline.py \ + --data_root data/ARKitScenes/data/upsampling \ + --encoder vitl \ + --pretrained_path /path/to/model.ckpt \ + --output_dir results/baseline +""" + +import argparse +import os +import sys + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +import torch.nn.functional as F + +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from dataset.dataset import MyARKitScenesDataset +from promptda.promptda_baseline import PromptDA +from promptda.utils.logger import Log +from training.metrics import compute_depth_metrics, aggregate_metrics + + +# --------------------------------------------------------------------------- # +# Args +# --------------------------------------------------------------------------- # + +def parse_args(): + p = argparse.ArgumentParser(description="PromptDA Baseline Inference") + + p.add_argument("--data_root", type=str, default="data/ARKitScenes/data/upsampling") + p.add_argument("--split", type=str, default="val", choices=["train", "val"]) + p.add_argument("--max_samples", type=int, default=None) + p.add_argument("--num_workers", type=int, default=4) + + p.add_argument("--encoder", type=str, default="vitl", + choices=["vits", "vitb", "vitl"]) + p.add_argument("--pretrained_path", type=str, + default="depth-anything/prompt-depth-anything-vitl") + + p.add_argument("--max_size", type=int, default=1008, + help="Max longer-side of image (floored to multiple of 14). " + "Default 1008 = 72x14.") + + p.add_argument("--output_dir", type=str, default="results/baseline") + p.add_argument("--save_depth", action="store_true") + p.add_argument("--batch_size", type=int, default=1) + + return p.parse_args() + + +# --------------------------------------------------------------------------- # +# Per-sample predict wrapper +# --------------------------------------------------------------------------- # + +def run_batch(model, image, prompt): + """ + image: (B, 3, H, W) + prompt: (B, 1, h, w) + returns: pred: (B, 1, H, W) + """ + preds = [] + for i in range(image.shape[0]): + depth_i = model.predict( + image[i].unsqueeze(0), # (1, 3, H, W) + prompt[i].unsqueeze(0), # (1, 1, h, w) + ) + + # Normalize về đúng (1, H, W) + depth_i = depth_i.squeeze() # loại bỏ TẤT CẢ dim=1 -> (H, W) + depth_i = depth_i.unsqueeze(0) # thêm channel dim -> (1, H, W) + + preds.append(depth_i) + + return torch.stack(preds, dim=0) # (B, 1, H, W) + + +# --------------------------------------------------------------------------- # +# Main +# --------------------------------------------------------------------------- # + +def main(): + args = parse_args() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + Log.info(f"Device : {device}") + Log.info(f"Encoder : {args.encoder}") + Log.info(f"Split : {args.split}") + Log.info(f"Batch size : {args.batch_size}") + Log.info(f"Max img size : {args.max_size} (multiple of 14)") + Log.info(f"Pretrained : {args.pretrained_path}") + + os.makedirs(args.output_dir, exist_ok=True) + if args.save_depth: + depth_dir = os.path.join(args.output_dir, "depth_maps") + os.makedirs(depth_dir, exist_ok=True) + + # ── Transform ───────────────────────────────────────────────────────── # + # ImageTransform: resize IMAGE tensor (C,H,W) -> multiple of 14 + # Dataset goi: sample["color_img"] = self.transform(sample["color_img"]) + + # ── Dataset ─────────────────────────────────────────────────────────── # + Log.info(f"Loading '{args.split}' dataset from: {args.data_root}") + + dataset = MyARKitScenesDataset( + root=args.data_root, + split=args.split, + ) + + if args.max_samples is not None: + dataset = torch.utils.data.Subset( + dataset, range(min(args.max_samples, len(dataset))) + ) + + loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=(device.type == "cuda"), + ) + Log.info(f"Dataset size : {len(dataset)} | Batches: {len(loader)}") + + # ── Model ────────────────────────────────────────────────────────────── # + Log.info("Loading PromptDA model...") + model = PromptDA.from_pretrained( + pretrained_model_name_or_path=args.pretrained_path + ).to(device).eval() + Log.info("Model ready.") + + # ── Inference ────────────────────────────────────────────────────────── # + all_metrics = [] + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(loader, desc="Inference")): + + image = batch["color_img"].to(device) # (B, 3, H, W) + depth_gt = batch["high_res_depth_img"].to(device) # (B, 1, H, W) + prompt = batch["low_res_depth_img"].to(device) # (B, 1, h, w) + + bouding_boxes = batch["bounding_box"] # list of (B, N, 4) + # bouding_boxes = [b.to(device) for b in bouding_boxes] + print(bouding_boxes) + + pred = run_batch(model, image, prompt) # (B, 1, H, W) + + if pred.shape[-2:] != depth_gt.shape[-2:]: + pred = torch.nn.functional.interpolate( + pred, + size=depth_gt.shape[-2:], + mode="bilinear", + align_corners=False, + ) + + all_metrics.append(compute_depth_metrics(pred, depth_gt)) + + if args.save_depth: + for i in range(pred.shape[0]): + fname = f"batch_{batch_idx:04d}_sample_{i:02d}.npy" + np.save(os.path.join(depth_dir, fname), pred[i, 0].cpu().numpy()) + + # ── Results ──────────────────────────────────────────────────────────── # + agg = aggregate_metrics(all_metrics) + + Log.info("=" * 60) + Log.info(f"Baseline Results [{args.split}]") + Log.info("=" * 60) + for k, v in sorted(agg.items()): + Log.info(f" {k:20s}: {v:.6f}") + Log.info("=" * 60) + + txt_path = os.path.join(args.output_dir, "baseline_metrics.txt") + with open(txt_path, "w") as f: + f.write(f"Baseline Inference Results ({args.split})\n") + f.write("=" * 60 + "\n") + for k, v in sorted(agg.items()): + f.write(f" {k:20s}: {v:.6f}\n") + f.write("=" * 60 + "\n") + Log.info(f"Saved -> {txt_path}") + + npz_path = os.path.join(args.output_dir, "baseline_metrics.npz") + np.savez(npz_path, **agg) + Log.info(f"Saved -> {npz_path}") + + return agg + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/inference_mlf.py b/inference_mlf.py new file mode 100644 index 0000000..e69de29 diff --git a/instruction.md b/instruction.md new file mode 100644 index 0000000..3f99f2b --- /dev/null +++ b/instruction.md @@ -0,0 +1,137 @@ +# Copilot Implementation Guide: Masked Local Fusion (MLF) for Prompt-Depth-Anything + +This document is a natural-language instruction contract for Copilot/Cursor. Do not generate code until you have read every section. Follow the sections in order. + +--- + +## 0. What You Are Building + +The goal is to sharpen depth predictions inside object regions by fusing spatially-precise local features back into the global backbone feature map. A binary spatial mask derived from bounding boxes ensures that this fusion only affects pixels inside detected objects — the background is never modified. + +The enhancement follows this logic: take the global feature map from the backbone, extract object-region features from it using RoI Align, project those features to match the global channel dimension, zero out everything outside the bounding boxes using a mask, then add the result back to the global feature map as a residual. + +There are three rules that must never be violated. First, never run the backbone a second time on cropped images — all local features must come from the already-computed global feature map. Second, the spatial mask must strictly zero out any contribution outside the bounding boxes. Third, the fusion must always be additive (residual), never replacing the original features. + +--- + +## 1. Stage 0 — Object Detection (Prerequisite) + +Before any depth feature fusion can happen, the pipeline needs bounding boxes. This is an entirely separate stage that must be added before the depth model runs. + +### Recommended approach: Offline detection with YOLOv8 + +Use a pretrained YOLOv8 model to pre-compute bounding boxes for every image in the dataset and save them alongside the images (for example, as `.json` files in the same folder structure). During training and inference, the dataloader reads these saved boxes instead of running the detector live. This keeps VRAM usage low and training speed high. + +### What Copilot should implement for this stage + +- Add a standalone detection script (not part of the model) that accepts a dataset directory, runs YOLOv8 on every image, and saves the bounding boxes in feature-map coordinate space to a sidecar file per image. +- The boxes must be saved in xyxy format and already scaled to the feature map resolution (not the original image resolution), because they will be passed directly to RoI Align with `spatial_scale=1.0`. +- The scaling factor is: feature map height divided by image height (and same for width). For a ViT-B/14 backbone processing a 518×518 image, the feature map is typically 37×37, so the scale factor is approximately 37/518. +- In the dataloader, load the sidecar box file that corresponds to each image. If no sidecar file exists for an image, pass an empty list of boxes — the MLF module must handle this gracefully and simply return the global features unchanged. + +### Why not online detection + +Running a detector inside the training loop doubles the GPU workload and makes batching harder. Offline detection is the standard approach in two-stage pipelines and is strongly recommended here. + +--- + +## 2. Stage 1 — Expose Intermediate Features from the Backbone + +The MLF module needs access to an intermediate feature map from the DINOv2/ViT backbone, not just the final depth output. + +### What Copilot should do + +- In `depth_anything_v2/dpt.py`, find the forward method of the DPT model or head where the backbone produces its layered outputs. +- Add an optional flag to the forward method (for example `return_intermediate`) that, when enabled, also returns one of the intermediate layer outputs alongside the final depth prediction. +- The recommended intermediate layer to expose is the deepest one (layer 4 in a 4-stage DPT), as it contains the most semantic information. This will serve as `F_global` for the MLF module. +- Do not change any of the existing DPT fusion logic. This is a non-destructive addition only. + +--- + +## 3. Stage 2 — The MaskedLocalFusion Module + +Create a new file at `models/masked_fusion.py` and define a PyTorch module called `MaskedLocalFusion` inside it. + +### What the module must do, step by step + +**Step A — Feature Extraction via RoI Align** +Use `torchvision.ops.roi_align` to extract fixed-size feature patches from `F_global` at the locations specified by the bounding boxes. The boxes are already in feature-map coordinates, so `spatial_scale` should be `1.0`. The sampling ratio must be consistent with the backbone's patch size — for ViT-B/14 (patch size 14), a sampling ratio of 2 is appropriate. The output will be a collection of small fixed-size patches, one per detected object across the batch. + +**Step B — Spatial Reconstruction** +Create a zero-initialized tensor with the same shape as `F_global`. Scatter each extracted patch back into the spatial location corresponding to its bounding box by resizing the patch to fill the box region. Then apply a 1×1 convolution (the Projector) to the entire reconstructed tensor to align the channel dimension and learn a fusion weight. + +**Step C — Masked Fusion** +Build a binary spatial mask with the same height and width as `F_global`. Every pixel that falls inside any bounding box gets the value 1.0; everything else stays 0.0. Multiply the projected local features element-wise by this mask. Then add the masked result to `F_global` as a residual connection. The output is `F_enhanced`, which has the same shape as `F_global`. + +### Module configuration + +The module should accept three parameters at initialization: the number of input channels (matching the backbone's intermediate feature channels), the RoI Align output size (7 is a sensible default), and the sampling ratio. + +### Edge cases Copilot must handle + +- If a batch image has no detected objects (empty box list), the module must return the original `F_global` for that image unchanged, with no error. +- Bounding box coordinates that extend beyond the feature map boundaries must be clamped before scattering or masking. + +--- + +## 4. Stage 3 — Integration into the Prompt Encoder + +Edit `prompt_da/models/prompt_encoder.py` to wire the MLF module into the existing pipeline. + +### What Copilot should do + +- Import `MaskedLocalFusion` from `models/masked_fusion`. +- Instantiate it inside `__init__` using the channel dimension of the backbone's intermediate feature map. +- In the `forward` method, call the backbone with `return_intermediate=True` to get both the depth prediction and `F_global`. +- If bounding boxes are provided, pass `F_global` through `self.mlf` to get `F_enhanced`. If no boxes are provided, use `F_global` as-is. +- Use the resulting feature map (enhanced or not) for all subsequent steps: point-prompt embedding and the DPT decoder input. +- The `forward` method signature should accept an optional `boxes` argument (list of tensors, one per image) that defaults to `None`. + +### Injection point + +The MLF call must happen between backbone feature extraction and the DPT decoder stages. It must not be placed inside the decoder or before the backbone. + +--- + +## 5. Training Strategy + +### Phase 1 — Freeze backbone, train MLF projector only +In the first training phase, freeze all backbone parameters. Only the MLF projector (the 1×1 convolution) should have gradients enabled. Train for a small number of epochs to let the projector learn a stable fusion weight before the backbone adapts. + +### Phase 2 — Unfreeze and fine-tune jointly +Unfreeze the backbone and train the full model jointly with a lower learning rate for the backbone than for the MLF module and decoder. + +### Loss function +Do not change the existing Scale-and-Shift Invariant depth loss used by Prompt-DA. MLF is a feature-level change and does not require a new loss term. + +--- + +## 6. VRAM Constraints — Copilot Must Follow All of These + +- Never crop the original image and re-run the backbone. Extract features from the existing `F_global` only. +- The zero-initialized reconstruction canvas must be created once per forward call, not inside any inner loop. +- Use bilinear interpolation when resizing patches back to box size. Do not use bicubic. +- The spatial mask must be a float tensor, not a boolean tensor, to avoid implicit type promotion during multiplication. +- Confirm that `torch.no_grad()` wraps the backbone during inference, including the `return_intermediate` path. + +--- + +## 7. Validation + +After implementation, write a single unit test in `tests/test_masked_fusion.py` that does the following without any code dependencies on the rest of the project. Create a dummy zero-valued global feature map and a single bounding box covering only the top-left quadrant. Run `MaskedLocalFusion` on it. Assert that every value outside the bounding box in the output remains exactly zero. This confirms the mask is working and no feature is leaking into the background. + +--- + +## 8. File Summary + +| File | Action | Purpose | +|---|---|---| +| `scripts/precompute_boxes.py` | Create | Offline YOLOv8 detection; saves boxes per image as sidecar files | +| `models/masked_fusion.py` | Create | MaskedLocalFusion module | +| `depth_anything_v2/dpt.py` | Edit | Expose intermediate feature map via optional flag | +| `prompt_da/models/prompt_encoder.py` | Edit | Instantiate and call MLF between backbone and decoder | +| `tests/test_masked_fusion.py` | Create | Background-leaking sanity check | + +--- + +*Masked Local Fusion Implementation Contract v1.1 — instruction-only* \ No newline at end of file diff --git a/promptda/__init__.py b/promptda/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/promptda/model/__init__.py b/promptda/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/promptda/model/dpt.py b/promptda/model/dpt.py index 8093297..71add06 100644 --- a/promptda/model/dpt.py +++ b/promptda/model/dpt.py @@ -104,7 +104,7 @@ def __init__(self, act_func, ) - def forward(self, out_features, patch_h, patch_w, prompt_depth=None): + def forward(self, out_features, patch_h, patch_w, prompt_depth=None, return_intermediate=False): out = [] for i, x in enumerate(out_features): if self.use_clstoken: @@ -142,4 +142,6 @@ def forward(self, out_features, patch_h, patch_w, prompt_depth=None): out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) out = self.scratch.output_conv2(out_feat) + if return_intermediate: + return out, layer_4_rn return out diff --git a/promptda/model/masked_fusion.py b/promptda/model/masked_fusion.py new file mode 100644 index 0000000..9ebfc9e --- /dev/null +++ b/promptda/model/masked_fusion.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import roi_align + + +class MaskedLocalFusion(nn.Module): + def __init__( + self, + in_channels: int, + roi_output_size: int = 7, + sampling_ratio: int = 2, + ): + super().__init__() + self.roi_output_size = roi_output_size + self.sampling_ratio = sampling_ratio + self.projector = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def _clamp_boxes(self, boxes: torch.Tensor, height: int, width: int) -> torch.Tensor: + if boxes.numel() == 0: + return boxes + + clamped = boxes.clone() + clamped[:, [0, 2]] = clamped[:, [0, 2]].clamp(min=0, max=width) + clamped[:, [1, 3]] = clamped[:, [1, 3]].clamp(min=0, max=height) + + # Ensure non-zero extent. + clamped[:, 2] = torch.maximum(clamped[:, 2], clamped[:, 0] + 1.0) + clamped[:, 3] = torch.maximum(clamped[:, 3], clamped[:, 1] + 1.0) + clamped[:, 2] = clamped[:, 2].clamp(max=width) + clamped[:, 3] = clamped[:, 3].clamp(max=height) + return clamped + + def forward( + self, + f_global: torch.Tensor, + boxes: Optional[List[torch.Tensor]] = None, + ) -> torch.Tensor: + if boxes is None or len(boxes) == 0: + # Keep an autograd path through MLF parameters even when no boxes are provided. + # Numerical output is unchanged. + return f_global + 0.0 * self.projector(f_global) + + batch_size, _, height, width = f_global.shape + if len(boxes) != batch_size: + raise ValueError(f"Expected {batch_size} box tensors, got {len(boxes)}") + + if all(b.numel() == 0 for b in boxes): + # Keep an autograd path through MLF parameters even when all box tensors are empty. + # Numerical output is unchanged. + return f_global + 0.0 * self.projector(f_global) + + device = f_global.device + dtype = f_global.dtype + + canvas = torch.zeros_like(f_global) + mask = torch.zeros((batch_size, 1, height, width), device=device, dtype=dtype) + + rois = [] + roi_index_map = [] + clamped_boxes_per_image: List[torch.Tensor] = [] + + for b_idx, b in enumerate(boxes): + if b.numel() == 0: + clamped = torch.empty((0, 4), device=device, dtype=dtype) + clamped_boxes_per_image.append(clamped) + continue + + b = b.to(device=device, dtype=dtype) + clamped = self._clamp_boxes(b, height=height, width=width) + clamped_boxes_per_image.append(clamped) + for i in range(clamped.shape[0]): + rois.append( + torch.tensor( + [ + float(b_idx), + clamped[i, 0].item(), + clamped[i, 1].item(), + clamped[i, 2].item(), + clamped[i, 3].item(), + ], + device=device, + dtype=dtype, + ) + ) + roi_index_map.append((b_idx, i)) + + if len(rois) == 0: + # Safe fallback that preserves gradient flow to MLF parameters. + return f_global + 0.0 * self.projector(f_global) + + rois_tensor = torch.stack(rois, dim=0) + roi_patches = roi_align( + input=f_global, + boxes=rois_tensor, + output_size=(self.roi_output_size, self.roi_output_size), + spatial_scale=1.0, + sampling_ratio=self.sampling_ratio, + aligned=True, + ) + + for patch_idx, (b_idx, local_idx) in enumerate(roi_index_map): + box = clamped_boxes_per_image[b_idx][local_idx] + x1 = int(torch.floor(box[0]).item()) + y1 = int(torch.floor(box[1]).item()) + x2 = int(torch.ceil(box[2]).item()) + y2 = int(torch.ceil(box[3]).item()) + + x1 = max(0, min(x1, width - 1)) + y1 = max(0, min(y1, height - 1)) + x2 = max(x1 + 1, min(x2, width)) + y2 = max(y1 + 1, min(y2, height)) + + target_h = y2 - y1 + target_w = x2 - x1 + resized_patch = F.interpolate( + roi_patches[patch_idx : patch_idx + 1], + size=(target_h, target_w), + mode="bilinear", + align_corners=False, + ) + canvas[b_idx : b_idx + 1, :, y1:y2, x1:x2] = ( + canvas[b_idx : b_idx + 1, :, y1:y2, x1:x2] + resized_patch + ) + mask[b_idx : b_idx + 1, :, y1:y2, x1:x2] = 1.0 + + projected = self.projector(canvas) + f_enhanced = f_global + projected * mask + return f_enhanced diff --git a/promptda/promptda.py b/promptda/promptda.py index 5704f91..2515b7b 100644 --- a/promptda/promptda.py +++ b/promptda/promptda.py @@ -1,121 +1,252 @@ +""" +promptda/promptda.py + +Two modes: + Baseline (use_mlf=False): freeze all, zero-shot eval + Experiment (use_mlf=True): freeze DINOv2, train MLF projector +""" + +import os +from pathlib import Path +from typing import Optional + import torch import torch.nn as nn -from promptda.model.dpt import DPTHead +from huggingface_hub import hf_hub_download + from promptda.model.config import model_configs +from promptda.model.dpt import DPTHead +from promptda.model.masked_fusion import MaskedLocalFusion from promptda.utils.logger import Log -import os -from pathlib import Path -from huggingface_hub import hf_hub_download class PromptDA(nn.Module): - patch_size = 14 # patch size of the pretrained dinov2 model - use_bn = False + + patch_size = 14 + use_bn = False use_clstoken = False - output_act = 'sigmoid' + output_act = 'sigmoid' + + HF_REPOS = { + "vits": "depth-anything/prompt-depth-anything-vits", + "vitb": "depth-anything/prompt-depth-anything-vitb", + "vitl": "depth-anything/prompt-depth-anything-vitl", + } - def __init__(self, - encoder='vitl', - ckpt_path='checkpoints/promptda_vitl.ckpt'): + def __init__( + self, + encoder: str = 'vitl', + ckpt_path: Optional[str] = None, + use_mlf: bool = True, + ): super().__init__() - model_config = model_configs[encoder] + self.encoder = encoder + self.use_mlf = use_mlf + model_config = model_configs[encoder] - self.encoder = encoder - self.model_config = model_config - module_path = Path(__file__) # From anywhere else: module_path = Path(inspect.getfile(PromptDA)) - package_base_dir = str(Path(*module_path.parts[:-2])) # extract path to PromptDA module, then resolve to repo base dir for dinov2 load - self.pretrained = torch.hub.load( + # ── Backbone: DINOv2 ────────────────────────────────────────────── + module_path = Path(__file__) + package_base_dir = str(Path(*module_path.parts[:-2])) + self.pretrained = torch.hub.load( f'{package_base_dir}/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', - pretrained=False) + pretrained=False, + ) dim = self.pretrained.blocks[0].attn.qkv.in_features - self.depth_head = DPTHead(nclass=1, - in_channels=dim, - features=model_config['features'], - out_channels=model_config['out_channels'], - use_bn=self.use_bn, - use_clstoken=self.use_clstoken, - output_act=self.output_act) - - # mean and std of the pretrained dinov2 model + + # ── Decoder: DPT head ───────────────────────────────────────────── + self.depth_head = DPTHead( + nclass=1, + in_channels=dim, + features=model_config['features'], + out_channels=model_config['out_channels'], + use_bn=self.use_bn, + use_clstoken=self.use_clstoken, + output_act=self.output_act, + ) + + # ── MLF (always init to keep state_dict shape stable) ───────────── + self.mlf = MaskedLocalFusion( + in_channels=dim, + roi_output_size=7, + sampling_ratio=2, + ) + + # ── ImageNet normalisation stats ────────────────────────────────── self.register_buffer('_mean', torch.tensor( [0.485, 0.456, 0.406]).view(1, 3, 1, 1)) self.register_buffer('_std', torch.tensor( [0.229, 0.224, 0.225]).view(1, 3, 1, 1)) - self.load_checkpoint(ckpt_path) - + # ── Load pretrained weights ─────────────────────────────────────── + if ckpt_path is not None: + self._load_pretrained_weights(ckpt_path) + + # ── Freeze strategy ─────────────────────────────────────────────── + self._apply_freeze_strategy() + + # ── Freeze strategy ─────────────────────────────────────────────────── + + def _apply_freeze_strategy(self): + if not self.use_mlf: + for p in self.parameters(): + p.requires_grad = False + else: + # Freeze DINOv2 + DPT head + for p in self.pretrained.parameters(): + p.requires_grad = False + for p in self.depth_head.parameters(): + p.requires_grad = False # ← freeze hẳn, không cần lr=0 trick + + # Chỉ train MLF + for p in self.mlf.parameters(): + p.requires_grad = True + + trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + total = sum(p.numel() for p in self.parameters()) + Log.info(f"Trainable: {trainable:,} / {total:,} params (MLF only)") + + # ── Constructors ────────────────────────────────────────────────────── + @classmethod - def from_pretrained(cls, pretrained_model_name_or_path = None, model_kwargs = None, **hf_kwargs): - """ - Load a model from a checkpoint file. - ### Parameters: - - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. - - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. - - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. - ### Returns: - - A new instance of `MoGe` with the parameters loaded from the checkpoint. - """ - ckpt_path = None + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[str] = None, + encoder: str = 'vitl', + use_mlf: bool = True, + **hf_kwargs, + ) -> "PromptDA": + if pretrained_model_name_or_path is None: + pretrained_model_name_or_path = cls.HF_REPOS[encoder] + if Path(pretrained_model_name_or_path).exists(): ckpt_path = pretrained_model_name_or_path else: - cached_checkpoint_path = hf_hub_download( + ckpt_path = hf_hub_download( repo_id=pretrained_model_name_or_path, repo_type="model", filename="model.ckpt", - **hf_kwargs + **hf_kwargs, ) - ckpt_path = cached_checkpoint_path - # model_config = checkpoint['model_config'] - # if model_kwargs is not None: - # model_config.update(model_kwargs) - if model_kwargs is None: - model_kwargs = {} - model_kwargs.update({'ckpt_path': ckpt_path}) - model = cls(**model_kwargs) - return model - - def load_checkpoint(self, ckpt_path): - if os.path.exists(ckpt_path): - Log.info(f'Loading checkpoint from {ckpt_path}') - checkpoint = torch.load(ckpt_path, map_location='cpu') - self.load_state_dict( - {k[9:]: v for k, v in checkpoint['state_dict'].items()}) + + return cls(encoder=encoder, ckpt_path=ckpt_path, use_mlf=use_mlf) + + # ── Checkpoint loading ──────────────────────────────────────────────── + + def _load_pretrained_weights(self, ckpt_path: str): + if not os.path.exists(ckpt_path): + Log.warn(f"Checkpoint does not exist: {ckpt_path}") + return + + Log.info(f'Loading checkpoint: {ckpt_path}') + checkpoint = torch.load(ckpt_path, map_location='cpu') + + # ── Identical to original: expects 'state_dict' with 'model.' prefix + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] else: - Log.warn(f'Checkpoint {ckpt_path} not found') + state_dict = checkpoint + + # Strip 9-char prefix 'model.xxx' exactly as original load_checkpoint does + # but fall back to generic stripping for flexibility + first_key = next(iter(state_dict)) + if first_key.startswith('model.'): + state_dict = {k[6:]: v for k, v in state_dict.items()} + Log.info("Stripped prefix: 'model.'") + elif '.' in first_key: + prefix = first_key.split('.')[0] + '.' + if all(k.startswith(prefix) for k in state_dict): + state_dict = {k[len(prefix):]: v for k, v in state_dict.items()} + Log.info(f"Stripped prefix: '{prefix}'") + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + + mlf_missing = [k for k in missing if k.startswith('mlf')] + critical_missing = [k for k in missing if not k.startswith('mlf')] + + if critical_missing: + Log.warn(f'Missing keys (unexpected): {critical_missing}') + if unexpected: + Log.warn(f'Unexpected keys: {unexpected}') + Log.info(f'MLF keys random init (expected): {len(mlf_missing)}') + + # ── Forward ─────────────────────────────────────────────────────────── - def forward(self, x, prompt_depth=None): + def forward( + self, + x: torch.Tensor, + prompt_depth: torch.Tensor, + boxes: Optional[list[torch.Tensor]] = None, + return_intermediate: bool = False, + ): assert prompt_depth is not None, 'prompt_depth is required' + prompt_depth, min_val, max_val = self.normalize(prompt_depth) - h, w = x.shape[-2:] + + h, w = x.shape[-2:] + patch_h = h // self.patch_size + patch_w = w // self.patch_size + + # ── Backbone forward (identical to original) ────────────────────── features = self.pretrained.get_intermediate_layers( - (x - self._mean) / self._std, self.model_config['layer_idxs'], - return_class_token=True) - patch_h, patch_w = h // self.patch_size, w // self.patch_size + (x - self._mean) / self._std, + self.model_config['layer_idxs'], + return_class_token=True, + ) + + if self.use_mlf and boxes is not None: + # Convert tuple→list so we can splice the deepest feature map + features = list(features) + + deepest_tokens, cls_token = features[-1] # (B, N, C) + f_global = deepest_tokens.permute(0, 2, 1).reshape( + deepest_tokens.shape[0], + deepest_tokens.shape[-1], + patch_h, patch_w, + ) # (B, C, pH, pW) + + f_enhanced = self.mlf(f_global, boxes) # (B, C, pH, pW) + enhanced_tokens = f_enhanced.flatten(2).permute(0, 2, 1).contiguous() + features[-1] = (enhanced_tokens, cls_token) + + # ── DPT head (identical call signature to original) ─────────────── depth = self.depth_head(features, patch_h, patch_w, prompt_depth) depth = self.denormalize(depth, min_val, max_val) + + if return_intermediate and self.use_mlf and boxes is not None: + return depth, f_enhanced return depth @torch.no_grad() - def predict(self, - image: torch.Tensor, - prompt_depth: torch.Tensor): - return self.forward(image, prompt_depth) - - def normalize(self, - prompt_depth: torch.Tensor): - B, C, H, W = prompt_depth.shape + def predict( + self, + image: torch.Tensor, + prompt_depth: torch.Tensor, + boxes: Optional[list[torch.Tensor]] = None, + ) -> torch.Tensor: + return self.forward(image, prompt_depth, boxes=boxes) + + # ── Helpers ─────────────────────────────────────────────────────────── + + @property + def model_config(self): + return model_configs[self.encoder] + + def normalize(self, prompt_depth: torch.Tensor): + """Identical to original normalize().""" + B = prompt_depth.shape[0] min_val = torch.quantile( - prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None] + prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True + )[:, :, None, None] max_val = torch.quantile( - prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None] + prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True + )[:, :, None, None] prompt_depth = (prompt_depth - min_val) / (max_val - min_val) return prompt_depth, min_val, max_val - def denormalize(self, - depth: torch.Tensor, - min_val: torch.Tensor, - max_val: torch.Tensor): - return depth * (max_val - min_val) + min_val + def denormalize(self, depth, min_val, max_val): + """Identical to original denormalize().""" + return depth * (max_val - min_val) + min_val \ No newline at end of file diff --git a/promptda/promptda_baseline.py b/promptda/promptda_baseline.py new file mode 100644 index 0000000..d50d94d --- /dev/null +++ b/promptda/promptda_baseline.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +from promptda.model.dpt import DPTHead +from promptda.model.config import model_configs +from promptda.utils.logger import Log +import os +from pathlib import Path +from huggingface_hub import hf_hub_download + + +class PromptDA(nn.Module): + patch_size = 14 # patch size of the pretrained dinov2 model + use_bn = False + use_clstoken = False + output_act = 'sigmoid' + + def __init__(self, + encoder='vitl', + ckpt_path='checkpoints/promptda_vitl.ckpt'): + super().__init__() + model_config = model_configs[encoder] + + self.encoder = encoder + self.model_config = model_config + module_path = Path(__file__) # From anywhere else: module_path = Path(inspect.getfile(PromptDA)) + package_base_dir = str(Path(*module_path.parts[:-2])) # extract path to PromptDA module, then resolve to repo base dir for dinov2 load + self.pretrained = torch.hub.load( + f'{package_base_dir}/torchhub/facebookresearch_dinov2_main', + 'dinov2_{:}14'.format(encoder), + source='local', + pretrained=False) + dim = self.pretrained.blocks[0].attn.qkv.in_features + self.depth_head = DPTHead(nclass=1, + in_channels=dim, + features=model_config['features'], + out_channels=model_config['out_channels'], + use_bn=self.use_bn, + use_clstoken=self.use_clstoken, + output_act=self.output_act) + + # mean and std of the pretrained dinov2 model + self.register_buffer('_mean', torch.tensor( + [0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('_std', torch.tensor( + [0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + self.load_checkpoint(ckpt_path) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path = None, model_kwargs = None, **hf_kwargs): + """ + Load a model from a checkpoint file. + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + ckpt_path = None + if Path(pretrained_model_name_or_path).exists(): + ckpt_path = pretrained_model_name_or_path + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.ckpt", + **hf_kwargs + ) + ckpt_path = cached_checkpoint_path + # model_config = checkpoint['model_config'] + # if model_kwargs is not None: + # model_config.update(model_kwargs) + if model_kwargs is None: + model_kwargs = {} + model_kwargs.update({'ckpt_path': ckpt_path}) + model = cls(**model_kwargs) + return model + + def load_checkpoint(self, ckpt_path): + if os.path.exists(ckpt_path): + Log.info(f'Loading checkpoint from {ckpt_path}') + checkpoint = torch.load(ckpt_path, map_location='cpu') + self.load_state_dict( + {k[9:]: v for k, v in checkpoint['state_dict'].items()}) + else: + Log.warn(f'Checkpoint {ckpt_path} not found') + + def forward(self, x, prompt_depth=None): + assert prompt_depth is not None, 'prompt_depth is required' + prompt_depth, min_val, max_val = self.normalize(prompt_depth) + h, w = x.shape[-2:] + features = self.pretrained.get_intermediate_layers( + (x - self._mean) / self._std, self.model_config['layer_idxs'], + return_class_token=True) + patch_h, patch_w = h // self.patch_size, w // self.patch_size + depth = self.depth_head(features, patch_h, patch_w, prompt_depth) + depth = self.denormalize(depth, min_val, max_val) + return depth + + @torch.no_grad() + def predict(self, + image: torch.Tensor, + prompt_depth: torch.Tensor): + return self.forward(image, prompt_depth) + + def normalize(self, + prompt_depth: torch.Tensor): + B, C, H, W = prompt_depth.shape + min_val = torch.quantile( + prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None] + max_val = torch.quantile( + prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None] + prompt_depth = (prompt_depth - min_val) / (max_val - min_val) + return prompt_depth, min_val, max_val + + def denormalize(self, + depth: torch.Tensor, + min_val: torch.Tensor, + max_val: torch.Tensor): + return depth * (max_val - min_val) + min_val \ No newline at end of file diff --git a/promptda/scripts/precompute_boxes.py b/promptda/scripts/precompute_boxes.py new file mode 100644 index 0000000..349611a --- /dev/null +++ b/promptda/scripts/precompute_boxes.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Optional + +import tyro +from PIL import Image +from tqdm.auto import tqdm + + +IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"} + + +def _find_images(input_dir: Path) -> list[Path]: + return sorted([p for p in input_dir.rglob("*") if p.suffix.lower() in IMAGE_SUFFIXES]) + + +def _to_feature_boxes( + boxes_xyxy: list[list[float]], + image_w: int, + image_h: int, + feat_w: int, + feat_h: int, +) -> list[list[float]]: + if len(boxes_xyxy) == 0: + return [] + + scale_x = feat_w / float(image_w) + scale_y = feat_h / float(image_h) + scaled = [] + for x1, y1, x2, y2 in boxes_xyxy: + sx1 = max(0.0, min(float(feat_w), x1 * scale_x)) + sy1 = max(0.0, min(float(feat_h), y1 * scale_y)) + sx2 = max(0.0, min(float(feat_w), x2 * scale_x)) + sy2 = max(0.0, min(float(feat_h), y2 * scale_y)) + if sx2 <= sx1 or sy2 <= sy1: + continue + scaled.append([sx1, sy1, sx2, sy2]) + return scaled + + +def main( + input_dir: str, + output_dir: Optional[str] = None, + model_name: str = "yolov8n.pt", + conf: float = 0.25, + iou: float = 0.7, + imgsz: int = 640, + patch_size: int = 14, + feature_height: Optional[int] = None, + feature_width: Optional[int] = None, + sidecar_suffix: str = ".json", + device: str = "cuda", +): + """Precompute YOLO boxes and save sidecar JSONs in feature-map coordinates (xyxy).""" + try: + from ultralytics import YOLO + except ImportError as exc: + raise RuntimeError( + "ultralytics is required for offline detection. Install it with: pip install ultralytics" + ) from exc + + input_path = Path(input_dir) + if not input_path.exists(): + raise FileNotFoundError(f"Input directory not found: {input_path}") + + output_path = Path(output_dir) if output_dir is not None else input_path + output_path.mkdir(parents=True, exist_ok=True) + + image_paths = _find_images(input_path) + detector = YOLO(model_name) + + for image_path in tqdm(image_paths, desc="Precomputing boxes"): + with Image.open(image_path) as im: + image_w, image_h = im.size + + feat_h = int(feature_height) if feature_height is not None else max(1, image_h // patch_size) + feat_w = int(feature_width) if feature_width is not None else max(1, image_w // patch_size) + + results = detector.predict( + source=str(image_path), + verbose=False, + conf=conf, + iou=iou, + imgsz=imgsz, + device=device, + ) + result = results[0] + raw_boxes = result.boxes.xyxy.detach().cpu().tolist() if result.boxes is not None else [] + + feature_boxes = _to_feature_boxes(raw_boxes, image_w, image_h, feat_w, feat_h) + + rel = image_path.relative_to(input_path) + + # We want the 'boxes' folder at the same level as the folder directly containing the image. + if str(rel.parent) == '.': + # Example: input_path is .../41048190/wide, image is img.jpg + # target => .../41048190/boxes + boxes_dir = output_path.parent / "boxes" + else: + # Example: input_path is .../41048190, image is wide/img.jpg + # rel.parent is 'wide', rel.parent.parent is '.' + # target => .../41048190/boxes + boxes_dir = output_path / rel.parent.parent / "boxes" + + sidecar_path = (boxes_dir / rel.name).with_suffix(sidecar_suffix) + sidecar_path.parent.mkdir(parents=True, exist_ok=True) + + payload = { + "image_path": str(rel), + "image_size": [image_h, image_w], + "feature_size": [feat_h, feat_w], + "boxes_xyxy_feature": feature_boxes, + } + sidecar_path.write_text(json.dumps(payload, indent=2)) + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/promptda/utils/__init__.py b/promptda/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index 81b3028..d9b8d0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,31 +1,42 @@ -# python version >= 3.9 <= 3.11 -# torch -torch==2.0.1 -torchvision==0.15.2 -torchaudio==2.0.2 -xformers==0.0.22 -# +# --- PyTorch & Xformers (Pre-built Binaries) --- +--extra-index-url https://download.pytorch.org/whl/cu121 +torchvision==0.21.0 +torchaudio==2.6.0 + +# --- Core Frameworks --- lightning==2.1.3 -imageio>=2.33.1 +ultralytics==8.3.0 + +# pip install https://github.com/isl-org/Open3D/releases/download/v0.18.0/open3d-0.18.0-cp312-cp312-manylinux_2_27_x86_64.whl + +# --- Computer Vision & Image Processing --- +opencv-python==4.9.0.80 Pillow>=10.1.0 +imageio>=2.33.1 imageio-ffmpeg -einops -tqdm -ipdb -# Diffusion -transformers -# ray -termcolor +scikit-image +matplotlib +trimesh +# open3d + +# --- Data & Operations --- numpy==1.26.4 -opencv-python==4.9.0.80 scipy -matplotlib h5py +einops +pandas +tqdm tyro==0.9.2 -# open3d +termcolor +ipdb +ninja +wandb -# app.py +# --- AI & Transformers Ecosystem --- +transformers + +# --- Deployment & Demo (app.py) --- gradio==4.44.1 gradio-imageslider==0.0.20 spaces -trimesh +wandb \ No newline at end of file diff --git a/tests/test_masked_fusion.py b/tests/test_masked_fusion.py new file mode 100644 index 0000000..b2678e4 --- /dev/null +++ b/tests/test_masked_fusion.py @@ -0,0 +1,20 @@ +import torch + +from promptda.model.masked_fusion import MaskedLocalFusion + + +def test_masked_local_fusion_preserves_background(): + module = MaskedLocalFusion(in_channels=4, roi_output_size=7, sampling_ratio=2) + with torch.no_grad(): + module.projector.weight.zero_() + module.projector.bias.fill_(1.0) + + f_global = torch.zeros(1, 4, 8, 8) + boxes = [torch.tensor([[0.0, 0.0, 4.0, 4.0]], dtype=torch.float32)] + + out = module(f_global, boxes) + + # Outside top-left quadrant must remain exactly zero. + outside = out.clone() + outside[:, :, 0:4, 0:4] = 0 + assert torch.all(outside == 0), "Background changed outside the box mask" diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/loss.py b/training/loss.py new file mode 100644 index 0000000..d0986fe --- /dev/null +++ b/training/loss.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import roi_align + +class ScaleAndShiftInvariantLoss(nn.Module): + def _compute_scale_shift(self, pred, target): + """Compute scale and shift for aligning pred to target in least-squares sense.""" + mean_pred = pred.mean() + mean_target = target.mean() + + var_pred = ((pred - mean_pred) ** 2).mean() + cov = ((pred - mean_pred) * (target - mean_target)).mean() + + if var_pred < 1e-8: + return torch.tensor(1.0, device=pred.device), mean_target - mean_pred + + scale = cov / var_pred + shift = mean_target - scale * mean_pred + + return scale, shift + + def forward(self, pred, target, mask=None): + if isinstance(pred, (list, tuple)): + pred = pred[0] + if isinstance(target, (list, tuple)): + target = target[0] + + if mask is None: + mask = (target > 0) + + if isinstance(mask, torch.Tensor): + mask = mask.to(device=target.device).bool() + else: + mask = torch.as_tensor(mask, dtype=torch.bool, device=target.device) + + pred_flat = pred[mask].flatten() + target_flat = target[mask].flatten() + + # Skip tiny or invalid masks to avoid degenerate scale/shift solves. + if pred_flat.numel() < 50: + return torch.tensor(0.0, device=pred.device, requires_grad=True) + + # Log-domain stabilization and NaN/Inf cleanup. + pred_flat = torch.log1p(pred_flat.clamp(min=0)) + target_flat = torch.log1p(target_flat.clamp(min=0)) + pred_flat = torch.nan_to_num(pred_flat, nan=0.0, posinf=0.0, neginf=0.0) + target_flat = torch.nan_to_num(target_flat, nan=0.0, posinf=0.0, neginf=0.0) + + scale, shift = self._compute_scale_shift(pred_flat, target_flat) + + # --- NEW: clamp scale/shift để tránh degenerate solution --- + scale = scale.clamp(-10, 10) + shift = shift.clamp(-10, 10) + + pred_aligned = scale * pred_flat + shift + loss = torch.mean((pred_aligned - target_flat) ** 2) + loss = torch.nan_to_num(loss, nan=0.0, posinf=1e4, neginf=0.0) + return loss + + +class LocalROILoss(nn.Module): + """SSI loss computed within each bounding box ROI. + + Args: + boxes: list of per-image tensors, each shaped (N_i, 4) in xyxy pixel coords. + """ + def __init__(self, roi_output_size=7): + super().__init__() + self.roi_output_size = roi_output_size + self._ssi = ScaleAndShiftInvariantLoss() + + def forward(self, pred, target, boxes=None): + if isinstance(pred, (list, tuple)): + pred = pred[0] + if isinstance(target, (list, tuple)): + target = target[0] + + if boxes is None or not any(b is not None and len(b) > 0 for b in boxes): + return torch.tensor(0.0, device=pred.device, requires_grad=True) + + _, _, H, W = pred.shape + losses = [] + + for i, b in enumerate(boxes): + if b is None or len(b) == 0: + continue + for box in b: + x1 = box[0].long().clamp(0, W - 1) + y1 = box[1].long().clamp(0, H - 1) + x2 = box[2].long().clamp(x1 + 1, W) + y2 = box[3].long().clamp(y1 + 1, H) + + # Skip degenerate ROIs. + if (x2 - x1) < 4 or (y2 - y1) < 4: + continue + + pred_roi = pred[i:i+1, :, y1:y2, x1:x2] + target_roi = target[i:i+1, :, y1:y2, x1:x2] + + valid_mask = target_roi > 0 + if valid_mask.sum() < 50: + continue + + # Skip near-constant depth slices where SSI is unstable. + if target_roi[valid_mask].std() < 1e-4: + continue + + losses.append(self._ssi.forward(pred_roi, target_roi)) + + if not losses: + return torch.tensor(0.0, device=pred.device, requires_grad=True) + + loss = torch.stack(losses).mean() + loss = torch.nan_to_num(loss, nan=0.0, posinf=1e4, neginf=0.0) + return loss + + +class CombinedLoss(nn.Module): + def __init__(self, lambda_local=0.05, roi_size=7): + super().__init__() + self.ssi_loss = ScaleAndShiftInvariantLoss() + self.local_loss = LocalROILoss(roi_output_size=roi_size) + self.lambda_local = lambda_local + + def forward(self, pred, target, boxes=None): + # SSI loss calculation + loss_ssi = self.ssi_loss(pred, target) + + has_boxes = ( + boxes is not None + and len(boxes) > 0 + and any(b is not None and len(b) > 0 for b in boxes) + ) + + if has_boxes: + loss_local = self.local_loss(pred, target, boxes) + total = loss_ssi + self.lambda_local * loss_local + else: + # Detect device from tensor or list of tensors + dev = pred[0].device if isinstance(pred, (list, tuple)) else pred.device + loss_local = torch.tensor(0.0, device=dev) + total = loss_ssi + + total = torch.nan_to_num(total, nan=0.0, posinf=1e4, neginf=0.0) + + return total, { + "loss_ssi": loss_ssi.item(), + "loss_local": loss_local.item(), + "loss_total": total.item(), + "has_boxes": has_boxes, + } \ No newline at end of file diff --git a/training/metrics.py b/training/metrics.py new file mode 100644 index 0000000..2933459 --- /dev/null +++ b/training/metrics.py @@ -0,0 +1,98 @@ +""" +promptda/model/metrics.py + +Evaluation metrics for depth estimation. +Used to compare baseline and MLF runs. +""" + +import torch + + +@torch.no_grad() +def compute_depth_metrics( + pred: torch.Tensor, # [B, 1, H, W] + gt: torch.Tensor, # [B, 1, H, W] +) -> dict: + """ + Compute comprehensive depth estimation metrics on one batch. + + Returns dict with: + `AbsRel`: Mean absolute relative error (lower is better). + `MAE`: Mean absolute error in meters (lower is better). + `RMSE`: Root mean squared error (lower is better). + `Log10`: Mean log10 error (lower is better). + `delta1`: Ratio of valid pixels with threshold < 1.25 (higher is better). + `delta2`: Ratio with threshold < $1.25^2$ (higher is better). + `delta3`: Ratio with threshold < $1.25^3$ (higher is better). + `SILog`: Scale-invariant logarithmic error (lower is better). + """ + # Ensure pred and gt are [B, H, W] by squeezing the channel dimension if it exists + if pred.dim() == 4: + pred = pred.squeeze(1) + if gt.dim() == 4: + gt = gt.squeeze(1) + + mask = gt > 0 + pred_m = pred[mask] + gt_m = gt[mask] + + if pred_m.numel() == 0: + return { + "AbsRel": 0.0, + "MAE": 0.0, + "RMSE": 0.0, + "Log10": 0.0, + "delta1": 0.0, + "delta2": 0.0, + "delta3": 0.0, + "SILog": 0.0, + } + + # === Error metrics without scaling === + mae = torch.abs(pred_m - gt_m).mean().item() + mse = ((pred_m - gt_m) ** 2).mean().item() + rmse = torch.sqrt(torch.tensor(mse)).item() + + # === Log metrics === + log10_error = torch.abs(torch.log10(pred_m + 1e-8) - torch.log10(gt_m + 1e-8)).mean().item() + + # Scale-invariant logarithmic error (SILog) + log_pred = torch.log(pred_m + 1e-8) + log_gt = torch.log(gt_m + 1e-8) + si_log = torch.sqrt(((log_pred - log_gt) ** 2).mean() - ((log_pred - log_gt).mean() ** 2)).item() + + # === Relative error (with median scaling) === + scale = gt_m.median() / (pred_m.median() + 1e-8) + pred_scaled = (pred_m * scale).clamp(min=1e-8) + + abs_rel = ((pred_scaled - gt_m).abs() / (gt_m + 1e-8)).mean().item() + + # === Threshold metrics (accuracy) === + ratio = torch.max(pred_scaled / (gt_m + 1e-8), gt_m / (pred_scaled + 1e-8)) + delta1 = (ratio < 1.25 ).float().mean().item() + delta2 = (ratio < 1.25 ** 2).float().mean().item() + delta3 = (ratio < 1.25 ** 3).float().mean().item() + + return { + "AbsRel": abs_rel, + "MAE": mae, + "RMSE": rmse, + "Log10": log10_error, + "delta1": delta1, + "delta2": delta2, + "delta3": delta3, + "SILog": si_log, + } + + +def aggregate_metrics(metrics_list: list[dict]) -> dict: + """ + Average a list of metric dictionaries into one dictionary. + """ + if not metrics_list: + return {} + keys = metrics_list[0].keys() + return { + k: sum(m[k] for m in metrics_list) / len(metrics_list) + for k in keys + } \ No newline at end of file diff --git a/training/optimizer.py b/training/optimizer.py new file mode 100644 index 0000000..8bd7fe8 --- /dev/null +++ b/training/optimizer.py @@ -0,0 +1,48 @@ +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR +import torch.nn as nn + + +def build_optimizer(model: nn.Module, lr_mlf: float) -> AdamW: + """ + Chỉ optimize MLF projector. + DPT head và DINOv2 đã được freeze trong PromptDA._apply_freeze_strategy() + → requires_grad=False → không cần đưa vào optimizer. + + Gradient vẫn chảy qua DPT head vì các activation tensor vẫn tracked, + chỉ có weight của DPT head là không update. + """ + mlf_params = [p for p in model.mlf.parameters() if p.requires_grad] + + if not mlf_params: + raise ValueError( + "No trainable parameters in MLF. " + "Check that use_mlf=True and MaskedLocalFusion is initialized." + ) + + return AdamW( + mlf_params, + lr=lr_mlf, + weight_decay=1e-4, + ) + + +def build_scheduler( + optimizer: AdamW, + steps_per_epoch: int, + epochs: int, + lr_mlf: float, +) -> OneCycleLR: + """ + OneCycleLR chỉ cho MLF — warmup 10% → peak → cooldown. + """ + return OneCycleLR( + optimizer, + max_lr=lr_mlf, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + pct_start=0.1, # 10% warmup + div_factor=10, # initial_lr = lr_mlf / 10 + final_div_factor=100, # min_lr = lr_mlf / 100 + anneal_strategy='cos', + ) \ No newline at end of file diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000..8999dbe --- /dev/null +++ b/training/train.py @@ -0,0 +1,226 @@ +""" +training/train.py + +Entry point for PromptDA baseline evaluation and MLF training. + +Dataset split convention: + data/ARKitScenes/Training/ → train + data/ARKitScenes/Validation/ → evaluate + +Reference commands: + + # Baseline: zero-shot evaluation (no training) + python training/train.py --use_mlf false --run_name baseline + + # Experiment: train only the MLF projector + python training/train.py --use_mlf true --run_name mlf +""" + +import argparse +import os +import random +import sys + +import numpy as np + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import torch +from torch.utils.data import DataLoader + +try: + import wandb +except ImportError: + wandb = None + +from dataset.dataset import MyARKitScenesDataset, collate_fn +from promptda.promptda import PromptDA +from promptda.utils.logger import Log +from training.optimizer import build_optimizer, build_scheduler +from training.trainer import Trainer + + +def str2bool(value: str) -> bool: + return value.lower() in {"1", "true", "t", "yes", "y"} + +def parse_args(): + p = argparse.ArgumentParser(description="PromptDA baseline vs MLF training") + + # Data + p.add_argument("--data_root", type=str, default="data/ARKitScenes/data/upsampling") + p.add_argument("--max_samples", type=int, default=None) + p.add_argument("--num_workers", type=int, default=4) + + # Model + p.add_argument("--encoder", type=str, default="vitl", + choices=["vits", "vitb", "vitl"]) + p.add_argument( + "--pretrained_path", + type=str, + default=None, + help="Local .ckpt path or Hugging Face repo id. None uses encoder default.", + ) + p.add_argument( + "--use_mlf", + type=str2bool, + default=True, + help="false = baseline (zero-shot) | true = train MLF projector", + ) + + # Training (used only when --use_mlf=true) + p.add_argument("--run_name", type=str, default="experiment") + p.add_argument("--epochs", type=int, default=20) + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--lr_mlf", type=float, default=1e-4, + help="Learning rate for the MLF projector") + + # Checkpoint + p.add_argument("--checkpoint_dir", type=str, default="./checkpoints") + p.add_argument("--resume", type=str, default=None, + help="Resume training from a trainer checkpoint") + + # Seed + p.add_argument("--seed", type=int, default=42, + help="Random seed for reproducibility") + + # Weights & Biases + p.add_argument("--use_wandb", type=str2bool, default=True, + help="Enable logging to Weights & Biases") + p.add_argument("--wandb_mode", type=str, default="online", choices=["online", "offline", "disabled"]) + + return p.parse_args() + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Enable deterministic CuDNN behavior (may reduce performance slightly). + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def build_loader(data_root, split, batch_size, num_workers, shuffle): + ds = MyARKitScenesDataset( + root=data_root, + split=split, + ) + loader = DataLoader( + ds, batch_size=batch_size, shuffle=shuffle, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=True, + ) + Log.info(f"{split}: {len(ds)} samples") + return loader + + +def main(): + args = parse_args() + set_seed(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + Log.info(f"Device : {device}") + Log.info(f"Seed : {args.seed}") + Log.info(f"Run : {args.run_name}") + Log.info(f"Mode : {'EXPERIMENT (train MLF)' if args.use_mlf else 'BASELINE (zero-shot)'}") + Log.info(f"Encoder : {args.encoder}") + + # Validation loader is required for both baseline and training modes. + val_loader = build_loader( + args.data_root, "val", args.batch_size, args.num_workers, shuffle=False, + ) + + # Model + model = PromptDA.from_pretrained( + pretrained_model_name_or_path=args.pretrained_path, + encoder=args.encoder, + use_mlf=args.use_mlf, + ).to(device) + + ckpt_dir = f"{args.checkpoint_dir}/{args.run_name}_{args.encoder}_{args.seed}" + os.makedirs(ckpt_dir, exist_ok=True) + + # Optional W&B run + wandb_run = None + if args.use_wandb: + if wandb is None: + raise ImportError("wandb is not installed. Install it with: pip install wandb") + wandb_run = wandb.init( + project="ObjectPromptDA", + entity="ObjectPromptDA", + name=f"{args.run_name}_{args.encoder}_seed{args.seed}", + dir=ckpt_dir, + mode=args.wandb_mode, + config=vars(args), + tags=["mlf" if args.use_mlf else "baseline", args.encoder], + ) + + # Trainer + trainer = Trainer( + model=model, + optimizer=None, + scheduler=None, + device=device, + ckpt_dir=ckpt_dir, + wandb_run=wandb_run, + ) + + # Baseline: zero-shot evaluation on validation split. + if not args.use_mlf: + Log.info("Baseline: running zero-shot evaluation on validation split...") + val_loss, metrics = trainer.eval_epoch(val_loader, epoch=0) + Log.info( + f"[BASELINE] AbsRel={metrics['AbsRel']:.4f} | " + f"δ<1.25={metrics['delta1']:.4f} | " + f"δ<1.25²={metrics['delta2']:.4f} | " + f"δ<1.25³={metrics['delta3']:.4f}" + ) + trainer.history["train_loss"].append(val_loss) + trainer.history["val_loss"].append(val_loss) + trainer.history["AbsRel"].append(metrics["AbsRel"]) + trainer.history["delta1"].append(metrics["delta1"]) + trainer.history["delta2"].append(metrics["delta2"]) + trainer.history["delta3"].append(metrics["delta3"]) + + trainer.plot_history() + trainer.save_checkpoint(epoch=0, metrics=metrics, tag="baseline_zeroshot") + trainer.log_wandb_metrics( + epoch=0, + train_loss=None, + val_loss=val_loss, + metrics=metrics, + stage="baseline", + ) + if wandb_run is not None: + wandb_run.finish() + return + + # Experiment: train MLF on training split and evaluate on validation split. + train_loader = build_loader( + args.data_root, "train", args.batch_size, args.num_workers, shuffle=True, + ) + + optimizer = build_optimizer(model, lr_mlf=args.lr_mlf) + scheduler = build_scheduler( + optimizer, + steps_per_epoch=len(train_loader), + epochs=args.epochs, + lr_mlf=args.lr_mlf, + ) + trainer.optimizer = optimizer + trainer.scheduler = scheduler + + if args.resume: + trainer.load_checkpoint(args.resume) + Log.info(f"Resumed from: {args.resume}") + + trainer.fit(train_loader, val_loader, epochs=args.epochs) + if wandb_run is not None: + wandb_run.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/training/trainer.py b/training/trainer.py new file mode 100644 index 0000000..3c4e06c --- /dev/null +++ b/training/trainer.py @@ -0,0 +1,301 @@ +"""Trainer utilities for PromptDA fine-tuning and evaluation.""" + +import json +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from tqdm import tqdm + +from promptda.utils.logger import Log +from training.loss import CombinedLoss +from training.metrics import aggregate_metrics, compute_depth_metrics + + +class Trainer: + """Owns train/eval loops, checkpointing, and metric plotting.""" + + def __init__( + self, + model: nn.Module, + optimizer, + scheduler, + device: torch.device, + ckpt_dir: str, + wandb_run: Any = None, + ): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.device = device + self.ckpt_dir = Path(ckpt_dir) + self.loss_fn = CombinedLoss() + self.best_abs_rel = float("inf") + self.wandb_run = wandb_run + self.history = { + "train_loss": [], + "val_loss": [], + "AbsRel": [], + "MAE": [], + "RMSE": [], + "Log10": [], + "delta1": [], + "delta2": [], + "delta3": [], + "SILog": [], + } + + self.ckpt_dir.mkdir(parents=True, exist_ok=True) + + def train_epoch(self, loader, epoch: int) -> float: + """Run one training epoch.""" + self.model.train() + total_loss = 0.0 + + for batch_idx, batch in enumerate(tqdm(loader, desc=f"Train {epoch}", leave=False)): + image = batch["color_img"].to(self.device) # (B, 3, H, W) + depth_gt = batch["high_res_depth_img"].to(self.device) # (B, 1, H, W) + prompt = batch["low_res_depth_img"].to(self.device) # (B, 1, h, w) + boxes = [b.to(self.device) for b in batch["boxes"]] + boxes_image = [b.to(self.device) for b in batch.get("boxes_image", [])] + + # First-batch sanity checks. + if epoch == 1 and batch_idx == 0: + total_boxes = sum(len(b) for b in boxes) + Log.info(f"[DEBUG] Batch boxes: {[len(b) for b in boxes]} — total={total_boxes}") + if total_boxes == 0: + Log.warn( + "[DEBUG] All boxes are empty. MLF may be bypassed. " + "Verify offline box generation." + ) + + pred = self.model(image, prompt, boxes) + + if pred.shape[-2:] != depth_gt.shape[-2:]: + pred = torch.nn.functional.interpolate( + pred, + size=depth_gt.shape[-2:], + mode="bilinear", + align_corners=False, + ) + + loss, _ = self.loss_fn(pred, depth_gt, boxes_image) + + if self.optimizer is None: + raise RuntimeError("Optimizer is None in training mode.") + + # First-batch gradient sanity check for MLF. + if epoch == 1 and batch_idx == 0: + self.optimizer.zero_grad() + loss.backward() + mlf_params = list(self.model.mlf.parameters()) + grads = [p.grad for p in mlf_params if p.grad is not None] + if grads: + grad_norm = sum(g.norm().item() for g in grads) + Log.info(f"[DEBUG] MLF gradient norm = {grad_norm:.6f}") + else: + Log.warn("[DEBUG] MLF has no gradients. Check graph and inputs.") + self.optimizer.step() + if self.scheduler is not None: + self.scheduler.step() + total_loss += loss.item() + continue + + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.optimizer.step() + if self.scheduler is not None: + self.scheduler.step() + + total_loss += loss.item() + + return total_loss / len(loader) + + @torch.no_grad() + def eval_epoch(self, loader, epoch: int) -> tuple[float, dict]: + """Run one evaluation epoch.""" + self.model.eval() + total_loss = 0.0 + metrics_list = [] + + for batch in tqdm(loader, desc=f"Val {epoch}", leave=False): + image = batch["color_img"].to(self.device) + depth_gt = batch["high_res_depth_img"].to(self.device) + prompt = batch["low_res_depth_img"].to(self.device) + + boxes_feat = [b.to(self.device) for b in batch["boxes"]] + boxes_image = [b.to(self.device) for b in batch["boxes_image"]] + + pred = self.model(image, prompt, boxes_feat) + + if pred.shape[-2:] != depth_gt.shape[-2:]: + pred = torch.nn.functional.interpolate( + pred, + size=depth_gt.shape[-2:], + mode="bilinear", + align_corners=False, + ) + + loss, loss_dict = self.loss_fn(pred, depth_gt, boxes_image) + + total_loss += loss.item() + metrics_list.append(compute_depth_metrics(pred, depth_gt)) + + avg_loss = total_loss / len(loader) + avg_metrics = aggregate_metrics(metrics_list) + return avg_loss, avg_metrics + + def save_checkpoint(self, epoch: int, metrics: dict, tag: str = "latest"): + """Save checkpoint state.""" + state = { + "epoch": epoch, + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None, + "metrics": metrics, + "history": self.history, + } + path = self.ckpt_dir / f"{tag}.pth" + torch.save(state, path) + return path + + def load_checkpoint(self, path: str): + """Load checkpoint state and restore history when available.""" + state = torch.load(path, map_location=self.device) + self.model.load_state_dict(state["model"]) + if self.optimizer is not None and state.get("optimizer") is not None: + self.optimizer.load_state_dict(state["optimizer"]) + if "history" in state: + self.history = state["history"] + + Log.info(f"Loaded checkpoint from {path} (epoch {state.get('epoch', 'unknown')})") + return state.get("epoch", 0) + + def plot_history(self): + """Save JSON metrics and a curve figure in the run directory.""" + if len(self.history["val_loss"]) == 0: + return + + with open(self.ckpt_dir / "metrics_history.json", "w") as f: + json.dump(self.history, f, indent=2) + + epochs = list(range(1, len(self.history["val_loss"]) + 1)) + has_train = len(self.history["train_loss"]) == len(epochs) + + plt.figure(figsize=(12, 8)) + + plt.subplot(2, 2, 1) + if has_train: + plt.plot(epochs, self.history["train_loss"], label="Train Loss", marker="o") + plt.plot(epochs, self.history["val_loss"], label="Val Loss", marker="o") + plt.title("Loss") + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.grid(True) + plt.legend() + + plt.subplot(2, 2, 2) + plt.plot(epochs, self.history["AbsRel"], label="AbsRel", color="red", marker="o") + plt.title("AbsRel (lower is better)") + plt.xlabel("Epoch") + plt.ylabel("AbsRel") + plt.grid(True) + plt.legend() + + plt.subplot(2, 2, 3) + plt.plot(epochs, self.history["delta1"], label=r"$\delta < 1.25$", marker="o") + plt.plot(epochs, self.history["delta2"], label=r"$\delta < 1.25^2$", marker="o") + plt.plot(epochs, self.history["delta3"], label=r"$\delta < 1.25^3$", marker="o") + plt.title("Accuracy") + plt.xlabel("Epoch") + plt.ylabel("Ratio") + plt.grid(True) + plt.legend() + + plt.tight_layout() + plt.savefig(self.ckpt_dir / "training_curves.png", dpi=150) + plt.close() + + def log_wandb_metrics( + self, + epoch: int, + val_loss: float, + metrics: dict, + train_loss: float | None = None, + stage: str = "train", + ): + """Log metrics to Weights & Biases when a run is available.""" + if self.wandb_run is None: + return + + log_data = { + "epoch": epoch, + "val/loss": val_loss, + "val/AbsRel": metrics["AbsRel"], + "val/MAE": metrics["MAE"], + "val/RMSE": metrics["RMSE"], + "val/Log10": metrics["Log10"], + "val/delta1": metrics["delta1"], + "val/delta2": metrics["delta2"], + "val/delta3": metrics["delta3"], + "val/SILog": metrics["SILog"], + "stage": stage, + } + + if train_loss is not None: + log_data["train/loss"] = train_loss + + if self.optimizer is not None: + for idx, group in enumerate(self.optimizer.param_groups): + log_data[f"lr/group_{idx}"] = float(group.get("lr", 0.0)) + + if len(self.optimizer.param_groups) >= 2: + log_data["lr/dpt_head"] = float(self.optimizer.param_groups[0].get("lr", 0.0)) + log_data["lr/mlf"] = float(self.optimizer.param_groups[1].get("lr", 0.0)) + + self.wandb_run.log(log_data) + + def fit(self, train_loader, val_loader, epochs: int): + """Run full training and update metrics/checkpoints each epoch.""" + start_epoch = len(self.history["train_loss"]) + 1 + + for epoch in range(start_epoch, epochs + 1): + train_loss = self.train_epoch(train_loader, epoch) + val_loss, metrics = self.eval_epoch(val_loader, epoch) + + self.history["train_loss"].append(train_loss) + self.history["val_loss"].append(val_loss) + self.history["AbsRel"].append(metrics["AbsRel"]) + self.history["MAE"].append(metrics["MAE"]) + self.history["RMSE"].append(metrics["RMSE"]) + self.history["Log10"].append(metrics["Log10"]) + self.history["delta1"].append(metrics["delta1"]) + self.history["delta2"].append(metrics["delta2"]) + self.history["delta3"].append(metrics["delta3"]) + self.history["SILog"].append(metrics["SILog"]) + + Log.info( + f"Epoch {epoch:03d}/{epochs} | " + f"train={train_loss:.4f} | val={val_loss:.4f} | " + f"AbsRel={metrics['AbsRel']:.4f} | " + f"δ<1.25={metrics['delta1']:.4f}" + ) + + if metrics["AbsRel"] < self.best_abs_rel: + self.best_abs_rel = metrics["AbsRel"] + self.save_checkpoint(epoch, metrics, tag="best") + Log.info(f" ✓ best.pth saved → AbsRel={self.best_abs_rel:.4f}") + + self.save_checkpoint(epoch, metrics, tag="latest") + self.plot_history() + self.log_wandb_metrics( + epoch=epoch, + train_loss=train_loss, + val_loss=val_loss, + metrics=metrics, + stage="train", + ) + + Log.info(f"Training done. Best AbsRel = {self.best_abs_rel:.4f}") \ No newline at end of file