@@ -243,9 +243,13 @@ def trim_by_quantile(
243243 if quantile >= 1.0 :
244244 return train_df , test_df
245245 q = train_df ["target_uj" ].quantile (quantile )
246+ low = train_df ["target_uj" ].quantile (1 - quantile )
247+ high = train_df ["target_uj" ].quantile (quantile )
248+
249+
246250 return (
247- train_df [train_df ["target_uj" ] <= q ],
248- test_df [test_df ["target_uj" ] <= q ],
251+ train_df [( train_df ["target_uj" ] >= low ) & ( train_df [ "target_uj" ] <= high ) ],
252+ test_df [( test_df ["target_uj" ] >= low ) & ( test_df [ "target_uj" ] <= high ) ],
249253 )
250254
251255
@@ -393,6 +397,7 @@ def main() -> None:
393397 description = "Train nonlinear PSYS model and distill to kernel-friendly linear+multi-LUT params."
394398 )
395399 parser .add_argument ("logfiles" , nargs = "+" , type = Path )
400+ parser .add_argument ("--test-data" , type = Path , help = "Optional separate dataset for testing" )
396401 parser .add_argument ("--mode" , choices = ("delta" ,), default = "delta" )
397402 parser .add_argument ("--alpha" , type = float , default = 10.0 , help = "L2 strength for non-negative linear fit" )
398403 parser .add_argument ("--test-frac" , type = float , default = 0.2 )
@@ -401,15 +406,26 @@ def main() -> None:
401406 parser .add_argument ("--random-seed" , type = int , default = 42 )
402407 args = parser .parse_args ()
403408
409+ # Check all files exist
404410 for path in args .logfiles :
405411 if not path .exists ():
406412 raise FileNotFoundError (path )
413+ if args .test_data and not args .test_data .exists ():
414+ raise FileNotFoundError (args .test_data )
407415
408- df = gather_rows (args .logfiles , args .mode , args .min_target_uj )
409- if df .empty or len (df ) < 40 :
416+ # Gather training data
417+ train_df = gather_rows (args .logfiles , args .mode , args .min_target_uj )
418+ if train_df .empty or len (train_df ) < 40 :
410419 raise RuntimeError ("Not enough usable rows (need >= 40 after filtering)" )
411420
412- train_df , test_df = split_random (df , args .test_frac , args .random_seed )
421+ # Gather test data
422+ if args .test_data :
423+ test_df = gather_rows ([args .test_data ], args .mode , args .min_target_uj )
424+ if test_df .empty :
425+ raise RuntimeError ("No usable rows in test-data file" )
426+ else :
427+ train_df , test_df = split_random (train_df , args .test_frac , args .random_seed )
428+
413429 train_df , test_df = trim_by_quantile (train_df , test_df , args .trim_upper_quantile )
414430 if train_df .empty or test_df .empty :
415431 raise RuntimeError ("Train/test split produced an empty set" )
@@ -505,7 +521,7 @@ def main() -> None:
505521 )
506522
507523 print (f"# mode={ args .mode } alpha={ args .alpha } " )
508- print (f"# rows_total= { len ( df ) } rows_train={ len (train_df )} rows_test={ len (test_df )} " )
524+ print (f"# rows_train={ len (train_df )} rows_test={ len (test_df )} " )
509525
510526 print_metrics ("test metrics (current module defaults)" , metric_bundle (y_test , pred_baseline_test ))
511527 print_metrics ("test metrics (linear non-negative ridge)" , metric_bundle (y_test , pred_lin_test ))
0 commit comments