Skip to content

Commit 5da4e03

Browse files
committed
Add an option to provide a separate test file for out-of-sample testing
1 parent 2c22c02 commit 5da4e03

1 file changed

Lines changed: 22 additions & 6 deletions

File tree

tools/model.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -243,9 +243,13 @@ def trim_by_quantile(
243243
if quantile >= 1.0:
244244
return train_df, test_df
245245
q = train_df["target_uj"].quantile(quantile)
246+
low = train_df["target_uj"].quantile(1 - quantile)
247+
high = train_df["target_uj"].quantile(quantile)
248+
249+
246250
return (
247-
train_df[train_df["target_uj"] <= q],
248-
test_df[test_df["target_uj"] <= q],
251+
train_df[(train_df["target_uj"] >= low) & (train_df["target_uj"] <= high)],
252+
test_df[(test_df["target_uj"] >= low) & (test_df["target_uj"] <= high)],
249253
)
250254

251255

@@ -393,6 +397,7 @@ def main() -> None:
393397
description="Train nonlinear PSYS model and distill to kernel-friendly linear+multi-LUT params."
394398
)
395399
parser.add_argument("logfiles", nargs="+", type=Path)
400+
parser.add_argument("--test-data", type=Path, help="Optional separate dataset for testing")
396401
parser.add_argument("--mode", choices=("delta",), default="delta")
397402
parser.add_argument("--alpha", type=float, default=10.0, help="L2 strength for non-negative linear fit")
398403
parser.add_argument("--test-frac", type=float, default=0.2)
@@ -401,15 +406,26 @@ def main() -> None:
401406
parser.add_argument("--random-seed", type=int, default=42)
402407
args = parser.parse_args()
403408

409+
# Check all files exist
404410
for path in args.logfiles:
405411
if not path.exists():
406412
raise FileNotFoundError(path)
413+
if args.test_data and not args.test_data.exists():
414+
raise FileNotFoundError(args.test_data)
407415

408-
df = gather_rows(args.logfiles, args.mode, args.min_target_uj)
409-
if df.empty or len(df) < 40:
416+
# Gather training data
417+
train_df = gather_rows(args.logfiles, args.mode, args.min_target_uj)
418+
if train_df.empty or len(train_df) < 40:
410419
raise RuntimeError("Not enough usable rows (need >= 40 after filtering)")
411420

412-
train_df, test_df = split_random(df, args.test_frac, args.random_seed)
421+
# Gather test data
422+
if args.test_data:
423+
test_df = gather_rows([args.test_data], args.mode, args.min_target_uj)
424+
if test_df.empty:
425+
raise RuntimeError("No usable rows in test-data file")
426+
else:
427+
train_df, test_df = split_random(train_df, args.test_frac, args.random_seed)
428+
413429
train_df, test_df = trim_by_quantile(train_df, test_df, args.trim_upper_quantile)
414430
if train_df.empty or test_df.empty:
415431
raise RuntimeError("Train/test split produced an empty set")
@@ -505,7 +521,7 @@ def main() -> None:
505521
)
506522

507523
print(f"# mode={args.mode} alpha={args.alpha}")
508-
print(f"# rows_total={len(df)} rows_train={len(train_df)} rows_test={len(test_df)}")
524+
print(f"# rows_train={len(train_df)} rows_test={len(test_df)}")
509525

510526
print_metrics("test metrics (current module defaults)", metric_bundle(y_test, pred_baseline_test))
511527
print_metrics("test metrics (linear non-negative ridge)", metric_bundle(y_test, pred_lin_test))

0 commit comments

Comments
 (0)