6 changes: 5 additions & 1 deletion .gitignore
@@ -20,9 +20,9 @@ __pycache__/
 build/
 results/
 checkpoints/
-data/
 develop-eggs/
 dist/
+data/ARKitScenes/data
 downloads/
 eggs/
 .eggs/
@@ -188,3 +188,7 @@ glove
 *.ipynb
 *.ply
 3rdparty
+
+# others
+promptda/scripts/yolov8n.pt
+data/ARKitScenes/data
210 changes: 210 additions & 0 deletions compare_inference.py
@@ -0,0 +1,210 @@
"""
Comparison script: Baseline vs MLF inference.

Runs both baseline and MLF models on the same dataset and compares metrics side-by-side.

Usage:
    python compare_inference.py --data_root data/ARKitScenes/data/upsampling \
        --encoder vitl \
        --mlf_checkpoint checkpoints/experiment/best.pth \
        --output_dir results/comparison
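
Outputs:
    comparison_results.json and comparison_results.txt in --output_dir.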
"""

import argparse
import json
import os
import sys
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Add project root to path
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from dataset.arkitscene import ARKitScenesDataset, collate_fn
from promptda.promptda import PromptDA
from promptda.utils.logger import Log
from training.metrics import compute_depth_metrics, aggregate_metrics


def parse_args():
    p = argparse.ArgumentParser(description="PromptDA Baseline vs MLF Comparison")

    # Data
    p.add_argument("--data_root", type=str, default="data/ARKitScenes/data/upsampling")
    p.add_argument("--split", type=str, default="Validation", choices=["Training", "Validation"])
    p.add_argument("--image_size", type=int, nargs=2, default=[196, 252])
    p.add_argument("--num_workers", type=int, default=4)

    # Model
    p.add_argument("--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl"])
    p.add_argument("--pretrained_path", type=str, default=None,
                   help="Baseline pretrained checkpoint")
    p.add_argument("--mlf_checkpoint", type=str, required=True,
                   help="MLF trained checkpoint")

    # Output
    p.add_argument("--output_dir", type=str, default="results/comparison")
    p.add_argument("--batch_size", type=int, default=4)

    return p.parse_args()


def load_checkpoint(ckpt_path, model, device):
"""Load checkpoint weights."""
if not os.path.exists(ckpt_path):
Log.warn(f"Checkpoint not found: {ckpt_path}")
return False

Log.info(f"Loading checkpoint: {ckpt_path}")
state = torch.load(ckpt_path, map_location=device)

if "model" in state:
state_dict = state["model"]
else:
state_dict = state

    model.load_state_dict(state_dict, strict=False)
    return True


def run_inference(model, loader, device, model_name):
"""Run inference and collect metrics."""
Log.info(f"Running {model_name} inference...")
all_metrics = []

with torch.no_grad():
for batch in tqdm(loader, desc=model_name, leave=False):
image = batch["image"].to(device)
depth_gt = batch["depth_gt"].to(device)
prompt = batch["prompt"].to(device)
boxes = [b.to(device) for b in batch["boxes"]]

            pred = model(image, prompt, boxes=boxes if boxes[0].numel() > 0 else None)
            metrics = compute_depth_metrics(pred, depth_gt)
            all_metrics.append(metrics)

    return aggregate_metrics(all_metrics)


def main():
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    Log.info(f"Device: {device}")
    Log.info(f"Encoder: {args.encoder}")
    Log.info(f"Split: {args.split}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load dataset
    Log.info(f"Loading {args.split} dataset...")
    dataset = ARKitScenesDataset(
        data_root=args.data_root,
        split=args.split,
        image_size=tuple(args.image_size),
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    Log.info(f"Dataset size: {len(dataset)}")

    # Load baseline model
    Log.info("Loading baseline model (no MLF)...")
    baseline_model = PromptDA.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_path,
        encoder=args.encoder,
        use_mlf=False,
    ).to(device).eval()

    # Load MLF model
    Log.info("Loading MLF model...")
    mlf_model = PromptDA.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_path,
        encoder=args.encoder,
        use_mlf=True,
    ).to(device).eval()

    # Load MLF checkpoint
    if not load_checkpoint(args.mlf_checkpoint, mlf_model, device):
        Log.warn("Using random MLF weights (this will give poor results)")

    # Run inference for both models
    baseline_metrics = run_inference(baseline_model, loader, device, "Baseline")
    mlf_metrics = run_inference(mlf_model, loader, device, "MLF")

    # Compute improvements
    improvements = {}
    for key in baseline_metrics.keys():
        baseline_val = baseline_metrics[key]
        mlf_val = mlf_metrics[key]

        # For metrics where lower is better (loss, error metrics)
        if key in ["AbsRel", "MAE", "RMSE", "Log10", "SILog"]:
            improvement = (baseline_val - mlf_val) / (baseline_val + 1e-8) * 100
            improvements[key] = improvement
        # For metrics where higher is better (delta)
        else:
            improvement = (mlf_val - baseline_val) / (baseline_val + 1e-8) * 100
            improvements[key] = improvement
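    # Worked example with illustrative numbers (not results from this PR):
    # baseline AbsRel 0.050 vs MLF 0.040 gives (0.050 - 0.040) / 0.050 * 100
    # = +20%; delta 0.90 vs 0.93 gives (0.93 - 0.90) / 0.90 * 100 ≈ +3.33%.
    # Either way, a positive percentage means the MLF model did better.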

    # Print comparison
    Log.info("=" * 80)
    Log.info(f"Baseline vs MLF Comparison ({args.split})")
    Log.info("=" * 80)
    Log.info(f"{'Metric':<15} {'Baseline':<15} {'MLF':<15} {'Improvement':<15}")
    Log.info("-" * 80)

    for key in sorted(baseline_metrics.keys()):
        baseline_val = baseline_metrics[key]
        mlf_val = mlf_metrics[key]
        improvement = improvements[key]

        improvement_sign = "↓" if improvement < 0 else "↑"
        log_str = f"{key:<15} {baseline_val:<15.6f} {mlf_val:<15.6f} {improvement_sign} {abs(improvement):>10.2f}%"
        Log.info(log_str)

    Log.info("=" * 80)

    # Save results
    results = {
        "baseline": baseline_metrics,
        "mlf": mlf_metrics,
        "improvements": improvements,
    }

    results_file = os.path.join(args.output_dir, "comparison_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2)
    Log.info(f"Results saved to: {results_file}")

    # Save comparison as text
    txt_file = os.path.join(args.output_dir, "comparison_results.txt")
    with open(txt_file, "w") as f:
        f.write(f"Baseline vs MLF Comparison ({args.split})\n")
        f.write("=" * 80 + "\n")
        f.write(f"{'Metric':<15} {'Baseline':<15} {'MLF':<15} {'Improvement':<15}\n")
        f.write("-" * 80 + "\n")
        for key in sorted(baseline_metrics.keys()):
            baseline_val = baseline_metrics[key]
            mlf_val = mlf_metrics[key]
            improvement = improvements[key]
            improvement_sign = "↓" if improvement < 0 else "↑"
            f.write(f"{key:<15} {baseline_val:<15.6f} {mlf_val:<15.6f} {improvement_sign} {abs(improvement):>10.2f}%\n")
        f.write("=" * 80 + "\n")
    Log.info(f"Comparison saved to: {txt_file}")

    return results


if __name__ == "__main__":
    main()
6 changes: 6 additions & 0 deletions data/ARKitScenes/.gitignore
@@ -0,0 +1,6 @@
.idea
venv
*pyc
online_prepared_data
offline_prepared_data
threedod/sample_data
71 changes: 71 additions & 0 deletions data/ARKitScenes/CODE_OF_CONDUCT.md
@@ -0,0 +1,71 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
11 changes: 11 additions & 0 deletions data/ARKitScenes/CONTRIBUTING.md
@@ -0,0 +1,11 @@
# Contribution Guide

Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.

While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.

## Before you get started

By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).

We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).