From 35d407c285430e6358c334f7bdf6c74e5f9e6300 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Wed, 1 Apr 2026 22:48:03 +0200 Subject: [PATCH 1/9] feat: add reversibility check for constrained leapfrog integrator Phase-aware reverse check: observe during warmup, enforce during sampling. Forward+backward leapfrog in leapfrog_constrained_checked detects non-reversible RATTLE/SHAKE projections. Non-reversible steps terminate the NUTS tree and reject the proposal (same as divergence handling). Internal tolerance fixed at 0.5 * eps^2 (not exposed in R API). Non-reversible counts stored per iteration and surfaced through nuts_diagnostics alongside divergence counts. R layer: reverse_check parameter in bgm(), threaded through spec, validation, and run_sampler dispatch to C++ backends (GGM, OMRF, mixed). C++ layer: SamplerConfig, sampler_base, nuts_sampler, hmc_sampler, leapfrog, nuts, hmc, step_result, chain_result all extended. Tests: 6 tests covering parameter validation, output structure, and integration with continuous GGM edge selection. --- R/RcppExports.R | 16 ++- R/bgm.R | 8 ++ R/bgm_spec.R | 27 ++-- R/build_output.R | 3 + R/nuts_diagnostics.R | 20 ++- R/run_sampler.R | 9 +- R/validate_sampler.R | 20 +-- man/bgm.Rd | 7 + src/RcppExports.cpp | 48 +++++-- src/ggm_gradient_interface.cpp | 63 +++++++++ src/mcmc/algorithms/hmc.cpp | 32 ++++- src/mcmc/algorithms/hmc.h | 6 +- src/mcmc/algorithms/leapfrog.cpp | 34 +++++ src/mcmc/algorithms/leapfrog.h | 74 +++++++++++ src/mcmc/algorithms/nuts.cpp | 79 +++++++++-- src/mcmc/algorithms/nuts.h | 8 +- src/mcmc/execution/chain_result.h | 6 +- src/mcmc/execution/chain_runner.cpp | 3 +- src/mcmc/execution/sampler_config.h | 6 + src/mcmc/execution/step_result.h | 8 +- src/mcmc/samplers/hmc_sampler.h | 10 +- src/mcmc/samplers/nuts_sampler.h | 10 +- src/mcmc/samplers/sampler_base.h | 9 ++ src/sample_ggm.cpp | 4 +- src/sample_mixed.cpp | 4 +- src/sample_omrf.cpp | 4 +- tests/testthat/test-fit-object-contract.R | 6 +- tests/testthat/test-reversibility-check.R | 155 ++++++++++++++++++++++ tests/testthat/test-validate-sampler.R | 4 +- 29 files changed, 609 insertions(+), 74 deletions(-) create mode 100644 tests/testthat/test-reversibility-check.R diff --git a/R/RcppExports.R b/R/RcppExports.R index 2bd030d0..b0cd6226 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -45,6 +45,10 @@ ggm_test_leapfrog_constrained <- function(x0, r0, step_size, n_steps, suf_stat, .Call(`_bgms_ggm_test_leapfrog_constrained`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, inv_mass_in) } +ggm_test_leapfrog_constrained_checked <- function(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor = 0.5, inv_mass_in = NULL) { + .Call(`_bgms_ggm_test_leapfrog_constrained_checked`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor, inv_mass_in) +} + .compute_ess_cpp <- function(array3d) { .Call(`_bgms_compute_ess_cpp`, array3d) } @@ -117,16 +121,16 @@ run_mixed_simulation_parallel <- function(mux_samples, disc_samples, muy_samples .Call(`_bgms_run_mixed_simulation_parallel`, mux_samples, disc_samples, muy_samples, cont_samples, cross_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } -sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) { - .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable) +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL, reverse_check = TRUE) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, reverse_check) } -sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) { - .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable) +sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL, reverse_check = TRUE) { + .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable, reverse_check) } -sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) { - .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable) +sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL, reverse_check = TRUE) { + .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable, reverse_check) } compute_Vn_mfm_sbm <- function(num_variables, dirichlet_alpha, t_max, lambda) { diff --git a/R/bgm.R b/R/bgm.R index 942a1cce..a468faf9 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -306,6 +306,12 @@ #' matrix during warmup (NUTS only). If \code{FALSE}, use the identity #' matrix. Default: \code{TRUE}. #' +#' @param reverse_check Logical. If \code{TRUE}, run a forward-backward +#' reversibility check after each constrained leapfrog step (RATTLE). +#' Non-reversible steps are rejected, protecting against numerical +#' integration failures. Only applies when \code{edge_selection = TRUE} +#' and the sampler uses gradient-based proposals. Default: \code{TRUE}. +#' #' @param chains Integer. Number of parallel chains to run. Default: \code{4}. #' #' @param cores Integer. Number of CPU cores for parallel execution. @@ -440,6 +446,7 @@ bgm = function( hmc_num_leapfrogs = 100, nuts_max_depth = 10, learn_mass_matrix = TRUE, + reverse_check = TRUE, chains = 4, cores = parallel::detectCores(), display_progress = c("per-chain", "total", "none"), @@ -514,6 +521,7 @@ bgm = function( hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, + reverse_check = reverse_check, chains = chains, cores = cores, seed = seed, diff --git a/R/bgm_spec.R b/R/bgm_spec.R index 243547f5..5db8d711 100644 --- a/R/bgm_spec.R +++ b/R/bgm_spec.R @@ -280,6 +280,7 @@ bgm_spec = function(x, hmc_num_leapfrogs = 100L, nuts_max_depth = 10L, learn_mass_matrix = TRUE, + reverse_check = TRUE, chains = 4L, cores = parallel::detectCores(), seed = NULL, @@ -327,20 +328,21 @@ bgm_spec = function(x, # --- Sampler (needs is_continuous and edge_selection early) ------------------ sampler = validate_sampler( - update_method = update_method, - target_accept = target_accept, - iter = iter, - warmup = warmup, + update_method = update_method, + target_accept = target_accept, + iter = iter, + warmup = warmup, hmc_num_leapfrogs = hmc_num_leapfrogs, - nuts_max_depth = nuts_max_depth, + nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - chains = chains, - cores = cores, - seed = seed, - display_progress = display_progress, - is_continuous = is_continuous, - edge_selection = if(model_type == "compare") FALSE else edge_selection, - verbose = verbose + reverse_check = reverse_check, + chains = chains, + cores = cores, + seed = seed, + display_progress = display_progress, + is_continuous = is_continuous, + edge_selection = if(model_type == "compare") FALSE else edge_selection, + verbose = verbose ) # --- Build by model type ---------------------------------------------------- @@ -1078,6 +1080,7 @@ sampler_sublist = function(s) { hmc_num_leapfrogs = as.integer(s$hmc_num_leapfrogs), nuts_max_depth = as.integer(s$nuts_max_depth), learn_mass_matrix = s$learn_mass_matrix, + reverse_check = s$reverse_check, seed = as.integer(s$seed), progress_type = as.integer(s$progress_type) ) diff --git a/R/build_output.R b/R/build_output.R index df2a83f8..b4622600 100644 --- a/R/build_output.R +++ b/R/build_output.R @@ -305,6 +305,7 @@ build_output_bgm = function(spec, raw) { } if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible if(!is.null(chain$energy)) res[["energy__"]] = chain$energy res }) @@ -330,6 +331,7 @@ build_output_bgm = function(spec, raw) { } if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible if(!is.null(chain$energy)) res[["energy__"]] = chain$energy res }) @@ -593,6 +595,7 @@ build_output_mixed_mrf = function(spec, raw) { } if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible if(!is.null(chain$energy)) res[["energy__"]] = chain$energy res }) diff --git a/R/nuts_diagnostics.R b/R/nuts_diagnostics.R index 854b2d2c..fa39e2da 100644 --- a/R/nuts_diagnostics.R +++ b/R/nuts_diagnostics.R @@ -110,11 +110,12 @@ check_warmup_complete = function(energy_mat) { # Returns: An invisible named list with: # - treedepth: Integer matrix (chains x iterations). # - divergent: Integer matrix (chains x iterations), 0/1. +# - non_reversible: Integer matrix (chains x iterations), 0/1. # - energy: Numeric matrix (chains x iterations). # - ebfmi: Numeric vector of per-chain E-BFMI values. # - warmup_check: Output of check_warmup_complete(). # - summary: List with total_divergences, max_tree_depth_hits, -# min_ebfmi, and warmup_incomplete (logical). +# min_ebfmi, total_non_reversible, and warmup_incomplete (logical). # ------------------------------------------------------------------------------ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) { nuts_chains = Filter(function(chain) { @@ -134,6 +135,12 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) divergent_mat = combine_diag("divergent__") energy_mat = combine_diag("energy__") + non_reversible_mat = if("non_reversible__" %in% names(nuts_chains[[1]])) { + combine_diag("non_reversible__") + } else { + matrix(0L, nrow = nrow(divergent_mat), ncol = ncol(divergent_mat)) + } + # E-BFMI per chain compute_ebfmi = function(energy) { mean(diff(energy)^2) / stats::var(energy) @@ -145,6 +152,7 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) # Summaries n_total = nrow(divergent_mat) * ncol(divergent_mat) total_divergences = sum(divergent_mat) + total_non_reversible = sum(non_reversible_mat) max_tree_depth_hits = sum(treedepth_mat == nuts_max_depth) min_ebfmi = min(ebfmi_per_chain) low_ebfmi_chains = which(ebfmi_per_chain < 0.2) @@ -169,6 +177,14 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) } } + if(total_non_reversible > 0) { + non_rev_rate = total_non_reversible / n_total + issues = c(issues, sprintf( + "Non-reversible steps: %d (%.3f%%) - constrained integrator round-trip failed", + total_non_reversible, 100 * non_rev_rate + )) + } + if(max_tree_depth_hits > 0) { if(depth_hit_rate > 0.01) { issues = c(issues, sprintf( @@ -203,11 +219,13 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) invisible(list( treedepth = treedepth_mat, divergent = divergent_mat, + non_reversible = non_reversible_mat, energy = energy_mat, ebfmi = ebfmi_per_chain, warmup_check = warmup_check, summary = list( total_divergences = total_divergences, + total_non_reversible = total_non_reversible, max_tree_depth_hits = max_tree_depth_hits, min_ebfmi = min_ebfmi, warmup_incomplete = any(warmup_check$warmup_incomplete) diff --git a/R/run_sampler.R b/R/run_sampler.R index 3eb7b684..3beffba0 100644 --- a/R/run_sampler.R +++ b/R/run_sampler.R @@ -95,7 +95,8 @@ run_sampler_ggm = function(spec) { target_acceptance = s$target_accept, max_tree_depth = s$nuts_max_depth, na_impute = m$na_impute, - missing_index_nullable = m$missing_index + missing_index_nullable = m$missing_index, + reverse_check = s$reverse_check ) out_raw @@ -152,7 +153,8 @@ run_sampler_omrf = function(spec) { target_acceptance = s$target_accept, max_tree_depth = s$nuts_max_depth, num_leapfrogs = s$hmc_num_leapfrogs, - pairwise_scaling_factors_nullable = p$pairwise_scaling_factors + pairwise_scaling_factors_nullable = p$pairwise_scaling_factors, + reverse_check = s$reverse_check ) out_raw @@ -211,7 +213,8 @@ run_sampler_mixed_mrf = function(spec) { num_leapfrogs = s$hmc_num_leapfrogs, na_impute = m$na_impute, missing_index_discrete_nullable = m$missing_index_discrete, - missing_index_continuous_nullable = m$missing_index_continuous + missing_index_continuous_nullable = m$missing_index_continuous, + reverse_check = s$reverse_check ) out_raw diff --git a/R/validate_sampler.R b/R/validate_sampler.R index 901de223..354b7bd6 100644 --- a/R/validate_sampler.R +++ b/R/validate_sampler.R @@ -93,6 +93,7 @@ validate_sampler = function(update_method, hmc_num_leapfrogs = 100, nuts_max_depth = 10, learn_mass_matrix = TRUE, + reverse_check = TRUE, chains = 4, cores = parallel::detectCores(), seed = NULL, @@ -197,16 +198,17 @@ validate_sampler = function(update_method, progress_type = progress_type_from_display_progress(display_progress) list( - update_method = update_method, - target_accept = target_accept, - iter = iter, - warmup = warmup, + update_method = update_method, + target_accept = target_accept, + iter = iter, + warmup = warmup, hmc_num_leapfrogs = hmc_num_leapfrogs, - nuts_max_depth = nuts_max_depth, + nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - chains = chains, - cores = cores, - seed = seed, - progress_type = progress_type + reverse_check = reverse_check, + chains = chains, + cores = cores, + seed = seed, + progress_type = progress_type ) } diff --git a/man/bgm.Rd b/man/bgm.Rd index 19a38c58..917dab8d 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -28,6 +28,7 @@ bgm( hmc_num_leapfrogs = 100, nuts_max_depth = 10, learn_mass_matrix = TRUE, + reverse_check = TRUE, chains = 4, cores = parallel::detectCores(), display_progress = c("per-chain", "total", "none"), @@ -137,6 +138,12 @@ Default: \code{10}.} matrix during warmup (NUTS only). If \code{FALSE}, use the identity matrix. Default: \code{TRUE}.} +\item{reverse_check}{Logical. If \code{TRUE}, run a forward-backward +reversibility check after each constrained leapfrog step (RATTLE). +Non-reversible steps are rejected, protecting against numerical +integration failures. Only applies when \code{edge_selection = TRUE} +and the sampler uses gradient-based proposals. Default: \code{TRUE}.} + \item{chains}{Integer. Number of parallel chains to run. Default: \code{4}.} \item{cores}{Integer. Number of CPU cores for parallel execution. diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index ad4e8d3c..6c00da84 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -191,6 +191,26 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// ggm_test_leapfrog_constrained_checked +Rcpp::List ggm_test_leapfrog_constrained_checked(const arma::vec& x0, const arma::vec& r0, double step_size, int n_steps, const arma::mat& suf_stat, int n, const arma::imat& edge_indicators, double pairwise_scale, double reverse_check_factor, Rcpp::Nullable inv_mass_in); +RcppExport SEXP _bgms_ggm_test_leapfrog_constrained_checked(SEXP x0SEXP, SEXP r0SEXP, SEXP step_sizeSEXP, SEXP n_stepsSEXP, SEXP suf_statSEXP, SEXP nSEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP reverse_check_factorSEXP, SEXP inv_mass_inSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::vec& >::type x0(x0SEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type r0(r0SEXP); + Rcpp::traits::input_parameter< double >::type step_size(step_sizeSEXP); + Rcpp::traits::input_parameter< int >::type n_steps(n_stepsSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type suf_stat(suf_statSEXP); + Rcpp::traits::input_parameter< int >::type n(nSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type edge_indicators(edge_indicatorsSEXP); + Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP); + Rcpp::traits::input_parameter< double >::type reverse_check_factor(reverse_check_factorSEXP); + Rcpp::traits::input_parameter< Rcpp::Nullable >::type inv_mass_in(inv_mass_inSEXP); + rcpp_result_gen = Rcpp::wrap(ggm_test_leapfrog_constrained_checked(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor, inv_mass_in)); + return rcpp_result_gen; +END_RCPP +} // compute_ess_cpp Rcpp::NumericVector compute_ess_cpp(Rcpp::NumericVector array3d); RcppExport SEXP _bgms_compute_ess_cpp(SEXP array3dSEXP) { @@ -516,8 +536,8 @@ BEGIN_RCPP END_RCPP } // sample_ggm -Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const bool na_impute, const Rcpp::Nullable missing_index_nullable); -RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP) { +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const bool reverse_check); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP reverse_checkSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -543,13 +563,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_nullable(missing_index_nullableSEXP); - rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)); + Rcpp::traits::input_parameter< const bool >::type reverse_check(reverse_checkSEXP); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, reverse_check)); return rcpp_result_gen; END_RCPP } // sample_mixed_mrf -Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable); -RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP) { +Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable, const bool reverse_check); +RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP, SEXP reverse_checkSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -577,13 +598,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_discrete_nullable(missing_index_discrete_nullableSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_continuous_nullable(missing_index_continuous_nullableSEXP); - rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)); + Rcpp::traits::input_parameter< const bool >::type reverse_check(reverse_checkSEXP); + rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable, reverse_check)); return rcpp_result_gen; END_RCPP } // sample_omrf -Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable); -RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP) { +Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable, const bool reverse_check); +RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP, SEXP reverse_checkSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -611,7 +633,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type pairwise_scaling_factors_nullable(pairwise_scaling_factors_nullableSEXP); - rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable)); + Rcpp::traits::input_parameter< const bool >::type reverse_check(reverse_checkSEXP); + rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable, reverse_check)); return rcpp_result_gen; END_RCPP } @@ -642,6 +665,7 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_ggm_test_logp_and_gradient_full", (DL_FUNC) &_bgms_ggm_test_logp_and_gradient_full, 5}, {"_bgms_ggm_test_project_momentum", (DL_FUNC) &_bgms_ggm_test_project_momentum, 4}, {"_bgms_ggm_test_leapfrog_constrained", (DL_FUNC) &_bgms_ggm_test_leapfrog_constrained, 9}, + {"_bgms_ggm_test_leapfrog_constrained_checked", (DL_FUNC) &_bgms_ggm_test_leapfrog_constrained_checked, 10}, {"_bgms_compute_ess_cpp", (DL_FUNC) &_bgms_compute_ess_cpp, 1}, {"_bgms_compute_rhat_cpp", (DL_FUNC) &_bgms_compute_rhat_cpp, 1}, {"_bgms_compute_indicator_ess_cpp", (DL_FUNC) &_bgms_compute_indicator_ess_cpp, 1}, @@ -660,9 +684,9 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_run_ggm_simulation_parallel", (DL_FUNC) &_bgms_run_ggm_simulation_parallel, 9}, {"_bgms_sample_mixed_mrf_gibbs", (DL_FUNC) &_bgms_sample_mixed_mrf_gibbs, 11}, {"_bgms_run_mixed_simulation_parallel", (DL_FUNC) &_bgms_run_mixed_simulation_parallel, 16}, - {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 22}, - {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 24}, - {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 23}, + {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 25}, + {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 25}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/ggm_gradient_interface.cpp b/src/ggm_gradient_interface.cpp index f28e1fbe..0a95f880 100644 --- a/src/ggm_gradient_interface.cpp +++ b/src/ggm_gradient_interface.cpp @@ -251,3 +251,66 @@ Rcpp::List ggm_test_leapfrog_constrained( Rcpp::Named("dH") = H_final - H0 ); } + + +// [[Rcpp::export]] +Rcpp::List ggm_test_leapfrog_constrained_checked( + const arma::vec& x0, + const arma::vec& r0, + double step_size, + int n_steps, + const arma::mat& suf_stat, + int n, + const arma::imat& edge_indicators, + double pairwise_scale, + double reverse_check_factor = 0.5, + Rcpp::Nullable inv_mass_in = R_NilValue) +{ + size_t p = edge_indicators.n_rows; + + arma::mat inc_prob(p, p, arma::fill::value(0.5)); + GGMModel model(static_cast(p), suf_stat, inc_prob, + edge_indicators, true, pairwise_scale); + + Memoizer::JointFn joint = [&model](const arma::vec& x) + -> std::pair { + return model.logp_and_gradient_full(x); + }; + Memoizer memo(joint); + + arma::vec inv_mass; + if(inv_mass_in.isNotNull()) { + inv_mass = Rcpp::as(inv_mass_in); + } else { + inv_mass = arma::ones(x0.n_elem); + } + + ProjectPositionFn proj_pos = [&model, &inv_mass](arma::vec& x) { + model.project_position(x, inv_mass); + }; + ProjectMomentumFn proj_mom = [&model, &inv_mass](arma::vec& r, const arma::vec& x) { + model.project_momentum(r, x, inv_mass); + }; + + arma::vec x = x0; + arma::vec r = r0; + int non_reversible_count = 0; + + for (int s = 0; s < n_steps; ++s) { + auto result = leapfrog_constrained_checked( + x, r, step_size, memo, inv_mass, proj_pos, proj_mom, + reverse_check_factor + ); + x = std::move(result.theta); + r = std::move(result.r); + if (!result.reversible) { + non_reversible_count++; + } + } + + return Rcpp::List::create( + Rcpp::Named("x") = Rcpp::wrap(x), + Rcpp::Named("r") = Rcpp::wrap(r), + Rcpp::Named("non_reversible_count") = non_reversible_count + ); +} diff --git a/src/mcmc/algorithms/hmc.cpp b/src/mcmc/algorithms/hmc.cpp index 51505f18..835308e4 100644 --- a/src/mcmc/algorithms/hmc.cpp +++ b/src/mcmc/algorithms/hmc.cpp @@ -184,7 +184,9 @@ StepResult hmc_step( const arma::vec& inv_mass_diag, const ProjectPositionFn& project_position, const ProjectMomentumFn& project_momentum, - SafeRNG& rng + SafeRNG& rng, + bool reverse_check, + double reverse_check_tol ) { Memoizer memo(joint); @@ -198,11 +200,31 @@ StepResult hmc_step( // Run num_leapfrogs constrained leapfrog steps arma::vec theta = init_theta; + bool non_reversible = false; for (int i = 0; i < num_leapfrogs; ++i) { - std::tie(theta, r) = leapfrog_constrained( - theta, r, step_size, memo, inv_mass_diag, - project_position, project_momentum - ); + if (reverse_check) { + auto checked = leapfrog_constrained_checked( + theta, r, step_size, memo, inv_mass_diag, + project_position, project_momentum, + reverse_check_tol + ); + theta = std::move(checked.theta); + r = std::move(checked.r); + if (!checked.reversible) { + non_reversible = true; + break; + } + } else { + std::tie(theta, r) = leapfrog_constrained( + theta, r, step_size, memo, inv_mass_diag, + project_position, project_momentum + ); + } + } + + // Non-reversible step forces rejection + if (non_reversible) { + return {init_theta, 0.0}; } double logp1 = memo.cached_log_post(theta); diff --git a/src/mcmc/algorithms/hmc.h b/src/mcmc/algorithms/hmc.h index 84eb949a..099aa8f3 100644 --- a/src/mcmc/algorithms/hmc.h +++ b/src/mcmc/algorithms/hmc.h @@ -150,6 +150,8 @@ StepResult hmc_step( * @param project_position SHAKE position projection callback * @param project_momentum RATTLE momentum projection callback * @param rng Thread-safe random number generator + * @param reverse_check Enable runtime reversibility check + * @param reverse_check_tol Factor for eps²-scaled reversibility tolerance * @return StepResult with accepted state and acceptance probability */ StepResult hmc_step( @@ -160,5 +162,7 @@ StepResult hmc_step( const arma::vec& inv_mass_diag, const ProjectPositionFn& project_position, const ProjectMomentumFn& project_momentum, - SafeRNG& rng + SafeRNG& rng, + bool reverse_check = true, + double reverse_check_tol = kReverseCheckFactor ); diff --git a/src/mcmc/algorithms/leapfrog.cpp b/src/mcmc/algorithms/leapfrog.cpp index 809afcef..65a4c306 100644 --- a/src/mcmc/algorithms/leapfrog.cpp +++ b/src/mcmc/algorithms/leapfrog.cpp @@ -70,6 +70,40 @@ std::pair leapfrog_constrained( } +ConstrainedLeapfrogResult leapfrog_constrained_checked( + const arma::vec& theta, + const arma::vec& r, + double eps, + Memoizer& memo, + const arma::vec& inv_mass_diag, + const ProjectPositionFn& project_position, + const ProjectMomentumFn& project_momentum, + double reverse_check_factor +) { + // --- Forward step --- + auto [theta_new, r_new] = leapfrog_constrained( + theta, r, eps, memo, inv_mass_diag, + project_position, project_momentum + ); + + // --- Backward step (negate momentum, step forward) --- + // Use a separate Memoizer so the caller's cache stays at theta_new. + Memoizer back_memo(memo.joint_fn); + arma::vec r_back = -r_new; + auto [theta_back, r_back_out] = leapfrog_constrained( + theta_new, r_back, eps, back_memo, inv_mass_diag, + project_position, project_momentum + ); + + // --- Reversibility check (eps^2-scaled max-norm) --- + double max_diff = arma::max(arma::abs(theta_back - theta)); + double tol = reverse_check_factor * eps * eps; + bool reversible = (max_diff <= tol); + + return {std::move(theta_new), std::move(r_new), reversible, max_diff}; +} + + LeapfrogJointResult leapfrog( const arma::vec& theta_init, const arma::vec& r_init, diff --git a/src/mcmc/algorithms/leapfrog.h b/src/mcmc/algorithms/leapfrog.h index 031ec383..45284591 100644 --- a/src/mcmc/algorithms/leapfrog.h +++ b/src/mcmc/algorithms/leapfrog.h @@ -163,6 +163,80 @@ std::pair leapfrog_constrained( ); +// --------------------------------------------------------------------------- +// Constrained leapfrog with runtime reversibility check +// --------------------------------------------------------------------------- + +/** + * Default factor for the epsilon-squared-scaled reversibility check. + * + * After each constrained leapfrog step the integrator takes a backward + * step and verifies that the position returns to within + * factor * eps^2 + * of the starting point in max-norm. + * + * The SHAKE direct solver's column-coupling produces O(eps^2) round-trip + * errors with a dimensionless coupling constant C in [0.001, 0.13] + * (measured on Wenchuan GGM, 18 variables, across step sizes 0.001-3.4). + * A factor of 0.5 provides generous headroom above the observed maximum + * C ~ 0.13 while catching genuine failures (divergent projection, + * near-singular Jacobian) where C >> 1. + * + * Using eps^2-scaling instead of an absolute tolerance avoids creating + * a hard ceiling on the adapted step size during warmup Stage 3c, + * which would otherwise cause catastrophic slowdown. + * + * See Zappa, Holmes-Cerfon & Goodman (2018). + */ +constexpr double kReverseCheckFactor = 0.5; + + +/** + * Result of a constrained leapfrog step with reversibility information. + */ +struct ConstrainedLeapfrogResult { + arma::vec theta; ///< Updated position + arma::vec r; ///< Updated momentum + bool reversible; ///< Whether the forward-backward check passed + double max_diff; ///< Max-norm of forward-backward position difference +}; + + +/** + * Constrained leapfrog step with a runtime reversibility check. + * + * Performs a forward constrained leapfrog step, then a backward step + * (negate momentum, step forward, negate again). If the round-trip + * position differs from the original by more than factor * eps^2 + * in max-norm, the step is flagged as non-reversible. + * + * This is the runtime analogue of Mici's ConstrainedLeapfrogIntegrator + * reverse check (Zappa et al., 2018; Lelièvre et al., 2019), adapted + * to use eps^2-scaled tolerance matching the O(eps^2) column-coupling + * error of the direct SHAKE solver. + * + * @param theta Current position (parameter vector) + * @param r Current momentum vector + * @param eps Step size for integration + * @param memo Memoizer caching gradient evaluations + * @param inv_mass_diag Diagonal of the inverse mass matrix + * @param project_position SHAKE position projection callback + * @param project_momentum RATTLE momentum projection callback + * @param reverse_check_factor Factor for eps^2-scaled tolerance + * @return ConstrainedLeapfrogResult with position, momentum, and reversibility flag + */ +ConstrainedLeapfrogResult leapfrog_constrained_checked( + const arma::vec& theta, + const arma::vec& r, + double eps, + Memoizer& memo, + const arma::vec& inv_mass_diag, + const ProjectPositionFn& project_position, + const ProjectMomentumFn& project_momentum, + double reverse_check_factor = kReverseCheckFactor +); + + /** * LeapfrogJointResult - Return type for multi-step leapfrog integration. * diff --git a/src/mcmc/algorithms/nuts.cpp b/src/mcmc/algorithms/nuts.cpp index df7b8bc5..6e68c7c8 100644 --- a/src/mcmc/algorithms/nuts.cpp +++ b/src/mcmc/algorithms/nuts.cpp @@ -65,24 +65,62 @@ BuildTreeResult build_tree( const arma::vec& inv_mass_diag, SafeRNG& rng, const ProjectPositionFn* project_position, - const ProjectMomentumFn* project_momentum + const ProjectMomentumFn* project_momentum, + bool reverse_check, + double reverse_check_tol ) { constexpr double Delta_max = 1000.0; if (j == 0) { // Base case: take a single leapfrog step arma::vec theta_new, r_new; + bool non_reversible = false; + double step_max_diff = 0.0; if (project_position && project_momentum) { - std::tie(theta_new, r_new) = leapfrog_constrained( + // Always run the checked variant so we can observe max_diff. + // The reverse_check flag controls whether we ACT on the result + // (i.e. terminate the tree). Observation is always on. + auto checked = leapfrog_constrained_checked( theta, r, v * step_size, memo, inv_mass_diag, - *project_position, *project_momentum + *project_position, *project_momentum, + reverse_check_tol ); + theta_new = std::move(checked.theta); + r_new = std::move(checked.r); + step_max_diff = checked.max_diff; + non_reversible = !checked.reversible; } else { std::tie(theta_new, r_new) = leapfrog_memo( theta, r, v * step_size, memo, inv_mass_diag ); } + // Non-reversible step terminates the tree only when reverse_check is on. + // When off (or during warmup observation mode), we record but don't act. + if (reverse_check && non_reversible) { + arma::vec p_sharp = inv_mass_diag % r_new; + BuildTreeResult result; + result.theta_min = theta_new; + result.theta_plus = theta_new; + result.r_min = r_new; + result.r_plus = r_new; + result.rho = r_new; + result.p_beg = r_new; + result.p_end = r_new; + result.r_prime = std::move(r_new); + result.theta_prime = std::move(theta_new); + result.p_sharp_beg = p_sharp; + result.p_sharp_end = std::move(p_sharp); + result.n_prime = 0; + result.s_prime = 0; + result.alpha = 0.0; + result.n_alpha = 1; + result.divergent = false; + result.non_reversible = true; + result.max_rev_diff = step_max_diff; + return result; + } + auto logp = memo.cached_log_post(theta_new); double kin = kinetic_energy(r_new, inv_mass_diag); int n_new = 1 * (log_u <= logp - kin); @@ -110,13 +148,16 @@ BuildTreeResult build_tree( result.alpha = alpha; result.n_alpha = 1; result.divergent = divergent; + result.non_reversible = non_reversible; // record even when not acting + result.max_rev_diff = step_max_diff; return result; } else { // Recursion: build the first subtree BuildTreeResult init_result = build_tree( theta, r, log_u, v, j - 1, step_size, theta_0, r0, logp0, kin0, memo, - inv_mass_diag, rng, project_position, project_momentum + inv_mass_diag, rng, project_position, project_momentum, reverse_check, + reverse_check_tol ); if (init_result.s_prime == 0) { @@ -125,6 +166,8 @@ BuildTreeResult build_tree( } bool divergent = init_result.divergent; + bool non_reversible = init_result.non_reversible; + double max_rev_diff = init_result.max_rev_diff; // Extract values from init subtree (move — init_result not used again) arma::vec theta_min = std::move(init_result.theta_min); @@ -147,7 +190,8 @@ BuildTreeResult build_tree( if (v == -1) { final_result = build_tree( theta_min, r_min, log_u, v, j - 1, step_size, theta_0, r0, logp0, - kin0, memo, inv_mass_diag, rng, project_position, project_momentum + kin0, memo, inv_mass_diag, rng, project_position, project_momentum, + reverse_check, reverse_check_tol ); // Update backward boundary theta_min = std::move(final_result.theta_min); @@ -155,7 +199,8 @@ BuildTreeResult build_tree( } else { final_result = build_tree( theta_plus, r_plus, log_u, v, j - 1, step_size, theta_0, r0, logp0, - kin0, memo, inv_mass_diag, rng, project_position, project_momentum + kin0, memo, inv_mass_diag, rng, project_position, project_momentum, + reverse_check, reverse_check_tol ); // Update forward boundary theta_plus = std::move(final_result.theta_plus); @@ -182,6 +227,8 @@ BuildTreeResult build_tree( result.alpha = alpha_prime + final_result.alpha; result.n_alpha = n_alpha_prime + final_result.n_alpha; result.divergent = divergent || final_result.divergent; + result.non_reversible = non_reversible || final_result.non_reversible; + result.max_rev_diff = std::max(max_rev_diff, final_result.max_rev_diff); return result; } @@ -195,6 +242,8 @@ BuildTreeResult build_tree( double alpha_double_prime = final_result.alpha; int n_alpha_double_prime = final_result.n_alpha; divergent = divergent || final_result.divergent; + non_reversible = non_reversible || final_result.non_reversible; + max_rev_diff = std::max(max_rev_diff, final_result.max_rev_diff); // Multinomial sampling from the combined subtree double denom = static_cast(n_prime + n_double_prime); @@ -246,6 +295,8 @@ BuildTreeResult build_tree( result.alpha = alpha_prime; result.n_alpha = n_alpha_prime; result.divergent = divergent; + result.non_reversible = non_reversible; + result.max_rev_diff = max_rev_diff; return result; } } @@ -259,11 +310,15 @@ StepResult nuts_step( SafeRNG& rng, int max_depth, const ProjectPositionFn* project_position, - const ProjectMomentumFn* project_momentum + const ProjectMomentumFn* project_momentum, + bool reverse_check, + double reverse_check_tol ) { // Create Memoizer with joint function Memoizer memo(joint); bool any_divergence = false; + bool any_non_reversible = false; + double worst_rev_diff = 0.0; arma::vec r0 = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, init_theta.n_elem); @@ -304,7 +359,8 @@ StepResult nuts_step( rho_fwd = rho; result = build_tree( theta_min, r_min, log_u, v, j, step_size, init_theta, r0, logp0, kin0, memo, - inv_mass_diag, rng, project_position, project_momentum + inv_mass_diag, rng, project_position, project_momentum, reverse_check, + reverse_check_tol ); theta_min = std::move(result.theta_min); r_min = std::move(result.r_min); @@ -316,7 +372,8 @@ StepResult nuts_step( rho_bck = rho; result = build_tree( theta_plus, r_plus, log_u, v, j, step_size, init_theta, r0, logp0, kin0, memo, - inv_mass_diag, rng, project_position, project_momentum + inv_mass_diag, rng, project_position, project_momentum, reverse_check, + reverse_check_tol ); theta_plus = std::move(result.theta_plus); r_plus = std::move(result.r_plus); @@ -327,6 +384,8 @@ StepResult nuts_step( } any_divergence = any_divergence || result.divergent; + any_non_reversible = any_non_reversible || result.non_reversible; + worst_rev_diff = std::max(worst_rev_diff, result.max_rev_diff); alpha = result.alpha; n_alpha = result.n_alpha; @@ -364,7 +423,9 @@ StepResult nuts_step( auto diag = std::make_shared(); diag->tree_depth = j; diag->divergent = any_divergence; + diag->non_reversible = any_non_reversible; diag->energy = energy; + diag->max_rev_diff = worst_rev_diff; return {theta, accept_prob, diag}; } diff --git a/src/mcmc/algorithms/nuts.h b/src/mcmc/algorithms/nuts.h index 7f78b74a..ac7e0a42 100644 --- a/src/mcmc/algorithms/nuts.h +++ b/src/mcmc/algorithms/nuts.h @@ -31,6 +31,8 @@ struct BuildTreeResult { double alpha; ///< Sum of acceptance probabilities in the subtree int n_alpha; ///< Number of proposals contributing to alpha bool divergent; ///< Whether this subtree diverged + bool non_reversible; ///< Whether a non-reversible step was detected + double max_rev_diff; ///< Worst-case reversibility max_diff in this subtree }; @@ -53,6 +55,8 @@ struct BuildTreeResult { * @param max_depth Maximum tree depth (default = 10) * @param project_position SHAKE position projection (nullptr for unconstrained) * @param project_momentum RATTLE momentum projection (nullptr for unconstrained) + * @param reverse_check Enable runtime reversibility check (constrained only) + * @param reverse_check_tol Factor for eps²-scaled reversibility tolerance * @return StepResult with position, acceptance probability, and NUTS diagnostics */ StepResult nuts_step( @@ -63,5 +67,7 @@ StepResult nuts_step( SafeRNG& rng, int max_depth = 10, const ProjectPositionFn* project_position = nullptr, - const ProjectMomentumFn* project_momentum = nullptr + const ProjectMomentumFn* project_momentum = nullptr, + bool reverse_check = true, + double reverse_check_tol = kReverseCheckFactor ); diff --git a/src/mcmc/execution/chain_result.h b/src/mcmc/execution/chain_result.h index ae1b1e3c..d03dd5ed 100644 --- a/src/mcmc/execution/chain_result.h +++ b/src/mcmc/execution/chain_result.h @@ -41,6 +41,8 @@ class ChainResult { arma::ivec treedepth_samples; /// NUTS/HMC divergent transition flags (n_iter). arma::ivec divergent_samples; + /// NUTS/HMC non-reversible step flags (n_iter). + arma::ivec non_reversible_samples; /// NUTS/HMC energy diagnostic (n_iter). arma::vec energy_samples; /// Whether NUTS/HMC diagnostics are stored. @@ -82,6 +84,7 @@ class ChainResult { void reserve_nuts_diagnostics(const size_t n_iter) { treedepth_samples.set_size(n_iter); divergent_samples.set_size(n_iter); + non_reversible_samples.set_size(n_iter); energy_samples.set_size(n_iter); has_nuts_diagnostics = true; } @@ -120,9 +123,10 @@ class ChainResult { * @param divergent Whether a divergence occurred * @param energy Final Hamiltonian energy */ - void store_nuts_diagnostics(const size_t iter, int tree_depth, bool divergent, double energy) { + void store_nuts_diagnostics(const size_t iter, int tree_depth, bool divergent, bool non_reversible, double energy) { treedepth_samples(iter) = tree_depth; divergent_samples(iter) = divergent ? 1 : 0; + non_reversible_samples(iter) = non_reversible ? 1 : 0; energy_samples(iter) = energy; } }; diff --git a/src/mcmc/execution/chain_runner.cpp b/src/mcmc/execution/chain_runner.cpp index 442f5abc..1aef3f4a 100644 --- a/src/mcmc/execution/chain_runner.cpp +++ b/src/mcmc/execution/chain_runner.cpp @@ -87,7 +87,7 @@ void run_mcmc_chain( if (chain_result.has_nuts_diagnostics && sampler->has_nuts_diagnostics()) { auto* diag = dynamic_cast(result.diagnostics.get()); if (diag) { - chain_result.store_nuts_diagnostics(sample_index, diag->tree_depth, diag->divergent, diag->energy); + chain_result.store_nuts_diagnostics(sample_index, diag->tree_depth, diag->divergent, diag->non_reversible, diag->energy); } } @@ -219,6 +219,7 @@ Rcpp::List convert_results_to_list(const std::vector& results) { if (chain.has_nuts_diagnostics) { chain_list["treedepth"] = chain.treedepth_samples; chain_list["divergent"] = chain.divergent_samples; + chain_list["non_reversible"] = chain.non_reversible_samples; chain_list["energy"] = chain.energy_samples; } } diff --git a/src/mcmc/execution/sampler_config.h b/src/mcmc/execution/sampler_config.h index 595c819a..2305a194 100644 --- a/src/mcmc/execution/sampler_config.h +++ b/src/mcmc/execution/sampler_config.h @@ -35,6 +35,12 @@ struct SamplerConfig { /// Enable missing-data imputation during sampling. bool na_impute = false; + /// Enable runtime reversibility check for constrained integration. + bool reverse_check = true; + + /// Factor for the eps²-scaled reversibility tolerance (tol = factor * eps²). + double reverse_check_tol = 0.5; + /// Random seed. int seed = 42; diff --git a/src/mcmc/execution/step_result.h b/src/mcmc/execution/step_result.h index 4768769f..490ed289 100644 --- a/src/mcmc/execution/step_result.h +++ b/src/mcmc/execution/step_result.h @@ -25,9 +25,11 @@ struct DiagnosticsBase { * NUTSDiagnostics - Per-iteration NUTS diagnostics (derives from DiagnosticsBase) */ struct NUTSDiagnostics : public DiagnosticsBase { - int tree_depth; ///< Depth of the trajectory tree - bool divergent; ///< Whether a divergence occurred - double energy; ///< Final Hamiltonian (-log posterior + kinetic energy) + int tree_depth; ///< Depth of the trajectory tree + bool divergent; ///< Whether a divergence occurred + bool non_reversible; ///< Whether a non-reversible constrained step occurred + double energy; ///< Final Hamiltonian (-log posterior + kinetic energy) + double max_rev_diff; ///< Worst-case reversibility max_diff across the tree }; diff --git a/src/mcmc/samplers/hmc_sampler.h b/src/mcmc/samplers/hmc_sampler.h index 1a6d7fb4..a4812bba 100644 --- a/src/mcmc/samplers/hmc_sampler.h +++ b/src/mcmc/samplers/hmc_sampler.h @@ -19,7 +19,9 @@ class HMCSampler : public GradientSamplerBase { public: explicit HMCSampler(const SamplerConfig& config, WarmupSchedule& schedule) : GradientSamplerBase(config.initial_step_size, config.target_acceptance, schedule), - num_leapfrogs_(config.num_leapfrogs) + num_leapfrogs_(config.num_leapfrogs), + reverse_check_(config.reverse_check), + reverse_check_tol_(config.reverse_check_tol) {} protected: @@ -74,11 +76,15 @@ class HMCSampler : public GradientSamplerBase { StepResult result = hmc_step( x, step_size_, joint_fn, num_leapfrogs_, inv_mass, - proj_pos, proj_mom, rng); + proj_pos, proj_mom, rng, + reverse_check_ && enforce_reverse_check_, + reverse_check_tol_); model.set_full_position(result.state); return result; } int num_leapfrogs_; + bool reverse_check_; + double reverse_check_tol_; }; diff --git a/src/mcmc/samplers/nuts_sampler.h b/src/mcmc/samplers/nuts_sampler.h index f73c7e97..4be8e2aa 100644 --- a/src/mcmc/samplers/nuts_sampler.h +++ b/src/mcmc/samplers/nuts_sampler.h @@ -19,7 +19,9 @@ class NUTSSampler : public GradientSamplerBase { public: explicit NUTSSampler(const SamplerConfig& config, WarmupSchedule& schedule) : GradientSamplerBase(config.initial_step_size, config.target_acceptance, schedule), - max_tree_depth_(config.max_tree_depth) + max_tree_depth_(config.max_tree_depth), + reverse_check_(config.reverse_check), + reverse_check_tol_(config.reverse_check_tol) {} bool has_nuts_diagnostics() const override { return true; } @@ -76,7 +78,9 @@ class NUTSSampler : public GradientSamplerBase { StepResult result = nuts_step( x, step_size_, joint_fn, inv_mass, rng, max_tree_depth_, - &proj_pos, &proj_mom + &proj_pos, &proj_mom, + reverse_check_ && enforce_reverse_check_, + reverse_check_tol_ ); model.set_full_position(result.state); @@ -84,4 +88,6 @@ class NUTSSampler : public GradientSamplerBase { } int max_tree_depth_; + bool reverse_check_; + double reverse_check_tol_; }; diff --git a/src/mcmc/samplers/sampler_base.h b/src/mcmc/samplers/sampler_base.h index efe5a03e..05f75ac0 100644 --- a/src/mcmc/samplers/sampler_base.h +++ b/src/mcmc/samplers/sampler_base.h @@ -114,6 +114,11 @@ class GradientSamplerBase : public SamplerBase { // Use adaptation controller's current step size for this iteration step_size_ = adapt_->current_step_size(); + // Phase-aware reverse check: observe during warmup, enforce during sampling. + // The check always runs (recording max_diff and non_reversible counts), + // but only terminates trees / rejects steps when enforcing. + enforce_reverse_check_ = schedule_.sampling(iteration); + StepResult result = do_gradient_step(model); // Let the adaptation controller handle step-size and mass-matrix logic. @@ -186,6 +191,10 @@ class GradientSamplerBase : public SamplerBase { double step_size_; double target_acceptance_; + /// Whether the reverse check should enforce (reject) this iteration. + /// Set in step() before do_gradient_step(). Derived classes read this. + bool enforce_reverse_check_ = false; + public: /** * Initialize the adaptation controller and run the step-size heuristic. diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 48ef75ea..dbbf1283 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -35,7 +35,8 @@ Rcpp::List sample_ggm( const double target_acceptance = 0.8, const int max_tree_depth = 10, const bool na_impute = false, - const Rcpp::Nullable missing_index_nullable = R_NilValue + const Rcpp::Nullable missing_index_nullable = R_NilValue, + const bool reverse_check = true ) { // Create model from R input @@ -62,6 +63,7 @@ Rcpp::List sample_ggm( config.target_acceptance = target_acceptance; config.max_tree_depth = max_tree_depth; config.na_impute = na_impute; + config.reverse_check = reverse_check; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); diff --git a/src/sample_mixed.cpp b/src/sample_mixed.cpp index e93b6a2e..38d9cb7a 100644 --- a/src/sample_mixed.cpp +++ b/src/sample_mixed.cpp @@ -75,7 +75,8 @@ Rcpp::List sample_mixed_mrf( const int num_leapfrogs = 100, const bool na_impute = false, const Rcpp::Nullable missing_index_discrete_nullable = R_NilValue, - const Rcpp::Nullable missing_index_continuous_nullable = R_NilValue + const Rcpp::Nullable missing_index_continuous_nullable = R_NilValue, + const bool reverse_check = true ) { // Extract model inputs from R list arma::imat discrete_obs = Rcpp::as(inputFromR["discrete_observations"]); @@ -132,6 +133,7 @@ Rcpp::List sample_mixed_mrf( config.target_acceptance = target_acceptance; config.max_tree_depth = max_tree_depth; config.num_leapfrogs = num_leapfrogs; + config.reverse_check = reverse_check; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index f2b96bcc..250ad2c1 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -66,7 +66,8 @@ Rcpp::List sample_omrf( const double target_acceptance = 0.8, const int max_tree_depth = 10, const int num_leapfrogs = 10, - const Rcpp::Nullable pairwise_scaling_factors_nullable = R_NilValue + const Rcpp::Nullable pairwise_scaling_factors_nullable = R_NilValue, + const bool reverse_check = true ) { // Create model from R input OMRFModel model = createOMRFModelFromR( @@ -106,6 +107,7 @@ Rcpp::List sample_omrf( config.max_tree_depth = max_tree_depth; config.num_leapfrogs = num_leapfrogs; config.na_impute = na_impute; + config.reverse_check = reverse_check; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); diff --git a/tests/testthat/test-fit-object-contract.R b/tests/testthat/test-fit-object-contract.R index 1c72d16e..fa6ced14 100644 --- a/tests/testthat/test-fit-object-contract.R +++ b/tests/testthat/test-fit-object-contract.R @@ -319,7 +319,8 @@ for(spec in get_bgms_fixtures()) { # (This is safe because the fixture cache returns the same object.) # Instead, just verify that accessing a summary field returns data. summary_fields = grep( - "^posterior_summary_", names(fit), value = TRUE + "^posterior_summary_", names(fit), + value = TRUE ) for(field in summary_fields) { val = fit[[field]] @@ -345,7 +346,8 @@ for(spec in get_bgmcompare_fixtures()) { { fit = spec$get_fit() summary_fields = grep( - "^posterior_summary_", names(fit), value = TRUE + "^posterior_summary_", names(fit), + value = TRUE ) for(field in summary_fields) { val = fit[[field]] diff --git a/tests/testthat/test-reversibility-check.R b/tests/testthat/test-reversibility-check.R new file mode 100644 index 00000000..a4d86d0c --- /dev/null +++ b/tests/testthat/test-reversibility-check.R @@ -0,0 +1,155 @@ +# --------------------------------------------------------------------------- # +# Reversibility check — runtime forward-backward round-trip for RATTLE. +# +# Tests verify: +# 1. Normal steps pass the reversibility check (reversible = TRUE) +# 2. A strict tolerance triggers non-reversible detections +# 3. The non_reversible diagnostic flows through to bgm() output +# --------------------------------------------------------------------------- # + + +# ---- Helpers (reused from test-rattle-leapfrog.R) --------------------------- + +make_test_phi_rc = function(p, seed = 42) { + set.seed(seed) + A = matrix(rnorm(p * p), p, p) + K = A %*% t(A) + diag(p) + Phi = chol(K) + list(Phi = Phi, K = K, p = p) +} + +phi_to_full_position_rc = function(Phi) { + p = nrow(Phi) + x = numeric(p * (p + 1) / 2) + idx = 1 + for(q in seq_len(p)) { + if(q > 1) { + for(i in seq_len(q - 1)) { + x[idx] = Phi[i, q] + idx = idx + 1 + } + } + x[idx] = log(Phi[q, q]) + idx = idx + 1 + } + x +} + +make_scenario_rc = function(p, edges, seed) { + dat = make_test_phi_rc(p, seed = seed) + n = 10 + set.seed(seed + 1000) + X = matrix(rnorm(n * p), n, p) + S = t(X) %*% X + scale = 2.5 + + x_raw = phi_to_full_position_rc(dat$Phi) + proj = ggm_test_project_position(x_raw, edges) + x0 = as.vector(proj$x_projected) + + set.seed(seed + 2000) + r_raw = rnorm(length(x0)) + r0 = as.vector(ggm_test_project_momentum(r_raw, x0, edges)) + + list(x0 = x0, r0 = r0, S = S, n = n, scale = scale, p = p, edges = edges) +} + + +# ---- 1. Normal steps pass the reversibility check -------------------------- + +test_that("constrained leapfrog checked passes for well-behaved steps", { + p = 4 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 3] = 0L + edges[3, 1] = 0L + + sc = make_scenario_rc(p, edges, seed = 300) + eps = 0.005 + n_steps = 10 + + result = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale, + reverse_check_factor = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +test_that("checked leapfrog matches unchecked positions", { + p = 4 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 3] = 0L + edges[3, 1] = 0L + edges[2, 4] = 0L + edges[4, 2] = 0L + + sc = make_scenario_rc(p, edges, seed = 301) + eps = 0.005 + n_steps = 8 + + checked = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale + ) + unchecked = ggm_test_leapfrog_constrained( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale + ) + + expect_equal(as.vector(checked$x), as.vector(unchecked$x), tolerance = 1e-12) + expect_equal(as.vector(checked$r), as.vector(unchecked$r), tolerance = 1e-12) +}) + + +# ---- 2. Strict tolerance triggers non-reversible detections ---------------- + +test_that("extreme tolerance detects non-reversible steps", { + p = 4 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 3] = 0L + edges[3, 1] = 0L + edges[2, 4] = 0L + edges[4, 2] = 0L + + sc = make_scenario_rc(p, edges, seed = 302) + eps = 0.01 + n_steps = 20 + + # With an impossibly tight factor, expect some failures + result = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale, + reverse_check_factor = 1e-11 + ) + + expect_gt(result$non_reversible_count, 0L) +}) + + +# ---- 3. Integration: non_reversible diagnostic in bgm() output ------------- + +test_that("bgm() nuts_diag includes non_reversible field", { + skip_on_cran() + + set.seed(1) + p = 4 + n = 80 + data = matrix(rnorm(n * p), n, p) + colnames(data) = paste0("V", seq_len(p)) + + fit = bgm( + data, + variable_type = "continuous", + iter = 10, + warmup = 50, + edge_selection = TRUE, + chains = 1, + display_progress = "none" + ) + + if(!is.null(fit$nuts_diag)) { + expect_true("non_reversible" %in% names(fit$nuts_diag)) + expect_true("total_non_reversible" %in% names(fit$nuts_diag$summary)) + } +}) diff --git a/tests/testthat/test-validate-sampler.R b/tests/testthat/test-validate-sampler.R index f6e39b6a..f2e63a03 100644 --- a/tests/testthat/test-validate-sampler.R +++ b/tests/testthat/test-validate-sampler.R @@ -13,6 +13,7 @@ vs = function(...) { hmc_num_leapfrogs = 100L, nuts_max_depth = 10L, learn_mass_matrix = TRUE, + reverse_check = TRUE, chains = 2L, cores = 2L, seed = 42L, @@ -388,11 +389,12 @@ test_that("FALSE → 0L (none)", { # 11. Full return structure # ============================================================================== -test_that("return list has all 11 expected elements", { +test_that("return list has all 12 expected elements", { res = vs() expected_names = c( "update_method", "target_accept", "iter", "warmup", "hmc_num_leapfrogs", "nuts_max_depth", "learn_mass_matrix", + "reverse_check", "chains", "cores", "seed", "progress_type" ) expect_named(res, expected_names) From 0ae3ce7abc7331c706b494643bd9d7a1ed9246cb Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Wed, 1 Apr 2026 23:14:44 +0200 Subject: [PATCH 2/9] refactor: remove reverse_check from public API The reversibility check is now always on (hardcoded in SamplerConfig). Removed the parameter from bgm(), bgm_spec(), validate_sampler(), run_sampler dispatchers, and the three C++ Rcpp::export entry points. Updated test-validate-sampler expected return list. --- R/RcppExports.R | 12 ++++++------ R/bgm.R | 8 -------- R/bgm_spec.R | 3 --- R/run_sampler.R | 9 +++------ R/validate_sampler.R | 2 -- man/bgm.Rd | 7 ------- src/RcppExports.cpp | 27 ++++++++++++-------------- src/sample_ggm.cpp | 4 +--- src/sample_mixed.cpp | 4 +--- src/sample_omrf.cpp | 4 +--- tests/testthat/test-validate-sampler.R | 4 +--- 11 files changed, 25 insertions(+), 59 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index b0cd6226..64ea7689 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -121,16 +121,16 @@ run_mixed_simulation_parallel <- function(mux_samples, disc_samples, muy_samples .Call(`_bgms_run_mixed_simulation_parallel`, mux_samples, disc_samples, muy_samples, cont_samples, cross_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } -sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL, reverse_check = TRUE) { - .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, reverse_check) +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable) } -sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL, reverse_check = TRUE) { - .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable, reverse_check) +sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) { + .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable) } -sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL, reverse_check = TRUE) { - .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable, reverse_check) +sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) { + .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable) } compute_Vn_mfm_sbm <- function(num_variables, dirichlet_alpha, t_max, lambda) { diff --git a/R/bgm.R b/R/bgm.R index a468faf9..942a1cce 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -306,12 +306,6 @@ #' matrix during warmup (NUTS only). If \code{FALSE}, use the identity #' matrix. Default: \code{TRUE}. #' -#' @param reverse_check Logical. If \code{TRUE}, run a forward-backward -#' reversibility check after each constrained leapfrog step (RATTLE). -#' Non-reversible steps are rejected, protecting against numerical -#' integration failures. Only applies when \code{edge_selection = TRUE} -#' and the sampler uses gradient-based proposals. Default: \code{TRUE}. -#' #' @param chains Integer. Number of parallel chains to run. Default: \code{4}. #' #' @param cores Integer. Number of CPU cores for parallel execution. @@ -446,7 +440,6 @@ bgm = function( hmc_num_leapfrogs = 100, nuts_max_depth = 10, learn_mass_matrix = TRUE, - reverse_check = TRUE, chains = 4, cores = parallel::detectCores(), display_progress = c("per-chain", "total", "none"), @@ -521,7 +514,6 @@ bgm = function( hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - reverse_check = reverse_check, chains = chains, cores = cores, seed = seed, diff --git a/R/bgm_spec.R b/R/bgm_spec.R index 5db8d711..f33d5c78 100644 --- a/R/bgm_spec.R +++ b/R/bgm_spec.R @@ -280,7 +280,6 @@ bgm_spec = function(x, hmc_num_leapfrogs = 100L, nuts_max_depth = 10L, learn_mass_matrix = TRUE, - reverse_check = TRUE, chains = 4L, cores = parallel::detectCores(), seed = NULL, @@ -335,7 +334,6 @@ bgm_spec = function(x, hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - reverse_check = reverse_check, chains = chains, cores = cores, seed = seed, @@ -1080,7 +1078,6 @@ sampler_sublist = function(s) { hmc_num_leapfrogs = as.integer(s$hmc_num_leapfrogs), nuts_max_depth = as.integer(s$nuts_max_depth), learn_mass_matrix = s$learn_mass_matrix, - reverse_check = s$reverse_check, seed = as.integer(s$seed), progress_type = as.integer(s$progress_type) ) diff --git a/R/run_sampler.R b/R/run_sampler.R index 3beffba0..3eb7b684 100644 --- a/R/run_sampler.R +++ b/R/run_sampler.R @@ -95,8 +95,7 @@ run_sampler_ggm = function(spec) { target_acceptance = s$target_accept, max_tree_depth = s$nuts_max_depth, na_impute = m$na_impute, - missing_index_nullable = m$missing_index, - reverse_check = s$reverse_check + missing_index_nullable = m$missing_index ) out_raw @@ -153,8 +152,7 @@ run_sampler_omrf = function(spec) { target_acceptance = s$target_accept, max_tree_depth = s$nuts_max_depth, num_leapfrogs = s$hmc_num_leapfrogs, - pairwise_scaling_factors_nullable = p$pairwise_scaling_factors, - reverse_check = s$reverse_check + pairwise_scaling_factors_nullable = p$pairwise_scaling_factors ) out_raw @@ -213,8 +211,7 @@ run_sampler_mixed_mrf = function(spec) { num_leapfrogs = s$hmc_num_leapfrogs, na_impute = m$na_impute, missing_index_discrete_nullable = m$missing_index_discrete, - missing_index_continuous_nullable = m$missing_index_continuous, - reverse_check = s$reverse_check + missing_index_continuous_nullable = m$missing_index_continuous ) out_raw diff --git a/R/validate_sampler.R b/R/validate_sampler.R index 354b7bd6..19f41ab0 100644 --- a/R/validate_sampler.R +++ b/R/validate_sampler.R @@ -93,7 +93,6 @@ validate_sampler = function(update_method, hmc_num_leapfrogs = 100, nuts_max_depth = 10, learn_mass_matrix = TRUE, - reverse_check = TRUE, chains = 4, cores = parallel::detectCores(), seed = NULL, @@ -205,7 +204,6 @@ validate_sampler = function(update_method, hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - reverse_check = reverse_check, chains = chains, cores = cores, seed = seed, diff --git a/man/bgm.Rd b/man/bgm.Rd index 917dab8d..19a38c58 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -28,7 +28,6 @@ bgm( hmc_num_leapfrogs = 100, nuts_max_depth = 10, learn_mass_matrix = TRUE, - reverse_check = TRUE, chains = 4, cores = parallel::detectCores(), display_progress = c("per-chain", "total", "none"), @@ -138,12 +137,6 @@ Default: \code{10}.} matrix during warmup (NUTS only). If \code{FALSE}, use the identity matrix. Default: \code{TRUE}.} -\item{reverse_check}{Logical. If \code{TRUE}, run a forward-backward -reversibility check after each constrained leapfrog step (RATTLE). -Non-reversible steps are rejected, protecting against numerical -integration failures. Only applies when \code{edge_selection = TRUE} -and the sampler uses gradient-based proposals. Default: \code{TRUE}.} - \item{chains}{Integer. Number of parallel chains to run. Default: \code{4}.} \item{cores}{Integer. Number of CPU cores for parallel execution. diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 6c00da84..f3883eb6 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -536,8 +536,8 @@ BEGIN_RCPP END_RCPP } // sample_ggm -Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const bool reverse_check); -RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP reverse_checkSEXP) { +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const bool na_impute, const Rcpp::Nullable missing_index_nullable); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -563,14 +563,13 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_nullable(missing_index_nullableSEXP); - Rcpp::traits::input_parameter< const bool >::type reverse_check(reverse_checkSEXP); - rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, reverse_check)); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)); return rcpp_result_gen; END_RCPP } // sample_mixed_mrf -Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable, const bool reverse_check); -RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP, SEXP reverse_checkSEXP) { +Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable); +RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -598,14 +597,13 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_discrete_nullable(missing_index_discrete_nullableSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_continuous_nullable(missing_index_continuous_nullableSEXP); - Rcpp::traits::input_parameter< const bool >::type reverse_check(reverse_checkSEXP); - rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable, reverse_check)); + rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)); return rcpp_result_gen; END_RCPP } // sample_omrf -Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable, const bool reverse_check); -RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP, SEXP reverse_checkSEXP) { +Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable); +RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -633,8 +631,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type pairwise_scaling_factors_nullable(pairwise_scaling_factors_nullableSEXP); - Rcpp::traits::input_parameter< const bool >::type reverse_check(reverse_checkSEXP); - rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable, reverse_check)); + rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable)); return rcpp_result_gen; END_RCPP } @@ -684,9 +681,9 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_run_ggm_simulation_parallel", (DL_FUNC) &_bgms_run_ggm_simulation_parallel, 9}, {"_bgms_sample_mixed_mrf_gibbs", (DL_FUNC) &_bgms_sample_mixed_mrf_gibbs, 11}, {"_bgms_run_mixed_simulation_parallel", (DL_FUNC) &_bgms_run_mixed_simulation_parallel, 16}, - {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 23}, - {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 25}, - {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 25}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 22}, + {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 24}, + {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index dbbf1283..48ef75ea 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -35,8 +35,7 @@ Rcpp::List sample_ggm( const double target_acceptance = 0.8, const int max_tree_depth = 10, const bool na_impute = false, - const Rcpp::Nullable missing_index_nullable = R_NilValue, - const bool reverse_check = true + const Rcpp::Nullable missing_index_nullable = R_NilValue ) { // Create model from R input @@ -63,7 +62,6 @@ Rcpp::List sample_ggm( config.target_acceptance = target_acceptance; config.max_tree_depth = max_tree_depth; config.na_impute = na_impute; - config.reverse_check = reverse_check; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); diff --git a/src/sample_mixed.cpp b/src/sample_mixed.cpp index 38d9cb7a..e93b6a2e 100644 --- a/src/sample_mixed.cpp +++ b/src/sample_mixed.cpp @@ -75,8 +75,7 @@ Rcpp::List sample_mixed_mrf( const int num_leapfrogs = 100, const bool na_impute = false, const Rcpp::Nullable missing_index_discrete_nullable = R_NilValue, - const Rcpp::Nullable missing_index_continuous_nullable = R_NilValue, - const bool reverse_check = true + const Rcpp::Nullable missing_index_continuous_nullable = R_NilValue ) { // Extract model inputs from R list arma::imat discrete_obs = Rcpp::as(inputFromR["discrete_observations"]); @@ -133,7 +132,6 @@ Rcpp::List sample_mixed_mrf( config.target_acceptance = target_acceptance; config.max_tree_depth = max_tree_depth; config.num_leapfrogs = num_leapfrogs; - config.reverse_check = reverse_check; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index 250ad2c1..f2b96bcc 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -66,8 +66,7 @@ Rcpp::List sample_omrf( const double target_acceptance = 0.8, const int max_tree_depth = 10, const int num_leapfrogs = 10, - const Rcpp::Nullable pairwise_scaling_factors_nullable = R_NilValue, - const bool reverse_check = true + const Rcpp::Nullable pairwise_scaling_factors_nullable = R_NilValue ) { // Create model from R input OMRFModel model = createOMRFModelFromR( @@ -107,7 +106,6 @@ Rcpp::List sample_omrf( config.max_tree_depth = max_tree_depth; config.num_leapfrogs = num_leapfrogs; config.na_impute = na_impute; - config.reverse_check = reverse_check; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); diff --git a/tests/testthat/test-validate-sampler.R b/tests/testthat/test-validate-sampler.R index f2e63a03..f6e39b6a 100644 --- a/tests/testthat/test-validate-sampler.R +++ b/tests/testthat/test-validate-sampler.R @@ -13,7 +13,6 @@ vs = function(...) { hmc_num_leapfrogs = 100L, nuts_max_depth = 10L, learn_mass_matrix = TRUE, - reverse_check = TRUE, chains = 2L, cores = 2L, seed = 42L, @@ -389,12 +388,11 @@ test_that("FALSE → 0L (none)", { # 11. Full return structure # ============================================================================== -test_that("return list has all 12 expected elements", { +test_that("return list has all 11 expected elements", { res = vs() expected_names = c( "update_method", "target_accept", "iter", "warmup", "hmc_num_leapfrogs", "nuts_max_depth", "learn_mass_matrix", - "reverse_check", "chains", "cores", "seed", "progress_type" ) expect_named(res, expected_names) From 47c3f90f40a78d70358cc4afe903b60935b85262 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Wed, 1 Apr 2026 23:20:44 +0200 Subject: [PATCH 3/9] test: add stress tests for reversibility check Five new scenarios exercising the RATTLE reverse check: - Large step sizes (eps=0.5) with near-machine-epsilon tolerance - Long trajectories (200 constrained leapfrog steps) - Rank-deficient regime (n=5, p=8) - Ill-conditioned data (condition number > 1e3) - High-dimensional GGM integration (p=20, edge selection) --- tests/testthat/test-reversibility-check.R | 146 ++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/tests/testthat/test-reversibility-check.R b/tests/testthat/test-reversibility-check.R index a4d86d0c..a9da9b09 100644 --- a/tests/testthat/test-reversibility-check.R +++ b/tests/testthat/test-reversibility-check.R @@ -5,6 +5,11 @@ # 1. Normal steps pass the reversibility check (reversible = TRUE) # 2. A strict tolerance triggers non-reversible detections # 3. The non_reversible diagnostic flows through to bgm() output +# 4. Stress: large step sizes amplify reversibility error +# 5. Stress: long trajectories (200 steps) remain reversible +# 6. Stress: small n / large p rank-deficient regime +# 7. Stress: ill-conditioned data (high condition number) +# 8. Stress: high-dimensional GGM (p=20) integration test # --------------------------------------------------------------------------- # @@ -153,3 +158,144 @@ test_that("bgm() nuts_diag includes non_reversible field", { expect_true("total_non_reversible" %in% names(fit$nuts_diag$summary)) } }) + + +# ---- 4. Stress: large step sizes ------------------------------------------- + +test_that("large step sizes increase reversibility error", { + p = 5 + edges = matrix(1L, p, p) + diag(edges) = 0L + + sc = make_scenario_rc(p, edges, seed = 400) + + # Small step: should be reversible + small = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, 0.005, 10, sc$S, sc$n, edges, sc$scale, + reverse_check_factor = 0.5 + ) + expect_equal(small$non_reversible_count, 0L) + + # Large step (100x): near-machine-epsilon tolerance catches round-trip drift + large = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, 0.5, 20, sc$S, sc$n, edges, sc$scale, + reverse_check_factor = 1e-15 + ) + expect_gt(large$non_reversible_count, 0L) +}) + + +# ---- 5. Stress: long trajectories (200 steps) ------------------------------ + +test_that("long trajectory (200 steps) stays reversible at normal tolerance", { + p = 5 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 4] = 0L + edges[4, 1] = 0L + + sc = make_scenario_rc(p, edges, seed = 500) + + result = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, 0.005, 200, sc$S, sc$n, edges, sc$scale, + reverse_check_factor = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +# ---- 6. Stress: small n / large p (rank-deficient S) ----------------------- + +test_that("rank-deficient regime (n < p) stays reversible", { + p = 8 + edges = matrix(1L, p, p) + diag(edges) = 0L + + # Build scenario with n < p (rank-deficient sufficient statistic) + dat = make_test_phi_rc(p, seed = 600) + n = 5 + set.seed(601) + X = matrix(rnorm(n * p), n, p) + S = t(X) %*% X # rank = min(n, p) = 5 < 8 + + x_raw = phi_to_full_position_rc(dat$Phi) + proj = ggm_test_project_position(x_raw, edges) + x0 = as.vector(proj$x_projected) + + set.seed(602) + r_raw = rnorm(length(x0)) + r0 = as.vector(ggm_test_project_momentum(r_raw, x0, edges)) + + result = ggm_test_leapfrog_constrained_checked( + x0, r0, 0.005, 30, S, n, edges, 2.5, + reverse_check_factor = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +# ---- 7. Stress: ill-conditioned data (high condition number) ---------------- + +test_that("ill-conditioned data (kappa ~ 1e4) stays reversible", { + p = 6 + edges = matrix(1L, p, p) + diag(edges) = 0L + + # Build ill-conditioned sufficient statistic + dat = make_test_phi_rc(p, seed = 700) + n = 50 + set.seed(701) + X = matrix(rnorm(n * p), n, p) + # Stretch first column to create high condition number + X[, 1] = X[, 1] * 100 + S = t(X) %*% X + # Verify ill-conditioning + ev = eigen(S, symmetric = TRUE, only.values = TRUE)$values + kappa = max(ev) / min(ev) + expect_gt(kappa, 1e3) + + x_raw = phi_to_full_position_rc(dat$Phi) + proj = ggm_test_project_position(x_raw, edges) + x0 = as.vector(proj$x_projected) + + set.seed(702) + r_raw = rnorm(length(x0)) + r0 = as.vector(ggm_test_project_momentum(r_raw, x0, edges)) + + result = ggm_test_leapfrog_constrained_checked( + x0, r0, 0.003, 30, S, n, edges, 2.5, + reverse_check_factor = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +# ---- 8. Stress: high-dimensional GGM (p=20) integration -------------------- + +test_that("high-dimensional GGM (p=20) completes without crash", { + skip_on_cran() + + set.seed(800) + p = 20 + n = 100 + data = matrix(rnorm(n * p), n, p) + colnames(data) = paste0("V", seq_len(p)) + + fit = bgm( + data, + variable_type = "continuous", + iter = 10, + warmup = 50, + edge_selection = TRUE, + chains = 1, + display_progress = "none" + ) + + expect_s3_class(fit, "bgms") + if(!is.null(fit$nuts_diag)) { + expect_true("non_reversible" %in% names(fit$nuts_diag)) + } +}) From f62b46230b35ed169a53faf8ecec141b79164b04 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:21:12 +0200 Subject: [PATCH 4/9] =?UTF-8?q?refactor:=20unify=20reverse=5Fcheck=5Ffacto?= =?UTF-8?q?r=20=E2=86=92=20reverse=5Fcheck=5Ftol,=20remove=20kReverseCheck?= =?UTF-8?q?Factor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename reverse_check_factor to reverse_check_tol in leapfrog layer to match HMC/NUTS/config naming. Remove kReverseCheckFactor constant; the default 0.5 lives in sampler_config.h and function signatures. --- R/RcppExports.R | 4 ++-- src/RcppExports.cpp | 8 ++++---- src/ggm_gradient_interface.cpp | 4 ++-- src/mcmc/algorithms/hmc.h | 2 +- src/mcmc/algorithms/leapfrog.cpp | 6 +++--- src/mcmc/algorithms/leapfrog.h | 29 ++--------------------------- src/mcmc/algorithms/nuts.h | 3 +-- src/mcmc/execution/step_result.h | 1 - 8 files changed, 15 insertions(+), 42 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 64ea7689..9cfbc936 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -45,8 +45,8 @@ ggm_test_leapfrog_constrained <- function(x0, r0, step_size, n_steps, suf_stat, .Call(`_bgms_ggm_test_leapfrog_constrained`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, inv_mass_in) } -ggm_test_leapfrog_constrained_checked <- function(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor = 0.5, inv_mass_in = NULL) { - .Call(`_bgms_ggm_test_leapfrog_constrained_checked`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor, inv_mass_in) +ggm_test_leapfrog_constrained_checked <- function(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol = 0.5, inv_mass_in = NULL) { + .Call(`_bgms_ggm_test_leapfrog_constrained_checked`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol, inv_mass_in) } .compute_ess_cpp <- function(array3d) { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index f3883eb6..28a6f1fa 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -192,8 +192,8 @@ BEGIN_RCPP END_RCPP } // ggm_test_leapfrog_constrained_checked -Rcpp::List ggm_test_leapfrog_constrained_checked(const arma::vec& x0, const arma::vec& r0, double step_size, int n_steps, const arma::mat& suf_stat, int n, const arma::imat& edge_indicators, double pairwise_scale, double reverse_check_factor, Rcpp::Nullable inv_mass_in); -RcppExport SEXP _bgms_ggm_test_leapfrog_constrained_checked(SEXP x0SEXP, SEXP r0SEXP, SEXP step_sizeSEXP, SEXP n_stepsSEXP, SEXP suf_statSEXP, SEXP nSEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP reverse_check_factorSEXP, SEXP inv_mass_inSEXP) { +Rcpp::List ggm_test_leapfrog_constrained_checked(const arma::vec& x0, const arma::vec& r0, double step_size, int n_steps, const arma::mat& suf_stat, int n, const arma::imat& edge_indicators, double pairwise_scale, double reverse_check_tol, Rcpp::Nullable inv_mass_in); +RcppExport SEXP _bgms_ggm_test_leapfrog_constrained_checked(SEXP x0SEXP, SEXP r0SEXP, SEXP step_sizeSEXP, SEXP n_stepsSEXP, SEXP suf_statSEXP, SEXP nSEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP reverse_check_tolSEXP, SEXP inv_mass_inSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -205,9 +205,9 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type n(nSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type edge_indicators(edge_indicatorsSEXP); Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP); - Rcpp::traits::input_parameter< double >::type reverse_check_factor(reverse_check_factorSEXP); + Rcpp::traits::input_parameter< double >::type reverse_check_tol(reverse_check_tolSEXP); Rcpp::traits::input_parameter< Rcpp::Nullable >::type inv_mass_in(inv_mass_inSEXP); - rcpp_result_gen = Rcpp::wrap(ggm_test_leapfrog_constrained_checked(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor, inv_mass_in)); + rcpp_result_gen = Rcpp::wrap(ggm_test_leapfrog_constrained_checked(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol, inv_mass_in)); return rcpp_result_gen; END_RCPP } diff --git a/src/ggm_gradient_interface.cpp b/src/ggm_gradient_interface.cpp index 0a95f880..7cec08bd 100644 --- a/src/ggm_gradient_interface.cpp +++ b/src/ggm_gradient_interface.cpp @@ -263,7 +263,7 @@ Rcpp::List ggm_test_leapfrog_constrained_checked( int n, const arma::imat& edge_indicators, double pairwise_scale, - double reverse_check_factor = 0.5, + double reverse_check_tol = 0.5, Rcpp::Nullable inv_mass_in = R_NilValue) { size_t p = edge_indicators.n_rows; @@ -299,7 +299,7 @@ Rcpp::List ggm_test_leapfrog_constrained_checked( for (int s = 0; s < n_steps; ++s) { auto result = leapfrog_constrained_checked( x, r, step_size, memo, inv_mass, proj_pos, proj_mom, - reverse_check_factor + reverse_check_tol ); x = std::move(result.theta); r = std::move(result.r); diff --git a/src/mcmc/algorithms/hmc.h b/src/mcmc/algorithms/hmc.h index 099aa8f3..7b09c10e 100644 --- a/src/mcmc/algorithms/hmc.h +++ b/src/mcmc/algorithms/hmc.h @@ -164,5 +164,5 @@ StepResult hmc_step( const ProjectMomentumFn& project_momentum, SafeRNG& rng, bool reverse_check = true, - double reverse_check_tol = kReverseCheckFactor + double reverse_check_tol = 0.5 ); diff --git a/src/mcmc/algorithms/leapfrog.cpp b/src/mcmc/algorithms/leapfrog.cpp index 65a4c306..69ff8f3d 100644 --- a/src/mcmc/algorithms/leapfrog.cpp +++ b/src/mcmc/algorithms/leapfrog.cpp @@ -78,7 +78,7 @@ ConstrainedLeapfrogResult leapfrog_constrained_checked( const arma::vec& inv_mass_diag, const ProjectPositionFn& project_position, const ProjectMomentumFn& project_momentum, - double reverse_check_factor + double reverse_check_tol ) { // --- Forward step --- auto [theta_new, r_new] = leapfrog_constrained( @@ -97,10 +97,10 @@ ConstrainedLeapfrogResult leapfrog_constrained_checked( // --- Reversibility check (eps^2-scaled max-norm) --- double max_diff = arma::max(arma::abs(theta_back - theta)); - double tol = reverse_check_factor * eps * eps; + double tol = reverse_check_tol * eps * eps; bool reversible = (max_diff <= tol); - return {std::move(theta_new), std::move(r_new), reversible, max_diff}; + return {std::move(theta_new), std::move(r_new), reversible}; } diff --git a/src/mcmc/algorithms/leapfrog.h b/src/mcmc/algorithms/leapfrog.h index 45284591..e58f7d79 100644 --- a/src/mcmc/algorithms/leapfrog.h +++ b/src/mcmc/algorithms/leapfrog.h @@ -167,30 +167,6 @@ std::pair leapfrog_constrained( // Constrained leapfrog with runtime reversibility check // --------------------------------------------------------------------------- -/** - * Default factor for the epsilon-squared-scaled reversibility check. - * - * After each constrained leapfrog step the integrator takes a backward - * step and verifies that the position returns to within - * factor * eps^2 - * of the starting point in max-norm. - * - * The SHAKE direct solver's column-coupling produces O(eps^2) round-trip - * errors with a dimensionless coupling constant C in [0.001, 0.13] - * (measured on Wenchuan GGM, 18 variables, across step sizes 0.001-3.4). - * A factor of 0.5 provides generous headroom above the observed maximum - * C ~ 0.13 while catching genuine failures (divergent projection, - * near-singular Jacobian) where C >> 1. - * - * Using eps^2-scaling instead of an absolute tolerance avoids creating - * a hard ceiling on the adapted step size during warmup Stage 3c, - * which would otherwise cause catastrophic slowdown. - * - * See Zappa, Holmes-Cerfon & Goodman (2018). - */ -constexpr double kReverseCheckFactor = 0.5; - - /** * Result of a constrained leapfrog step with reversibility information. */ @@ -198,7 +174,6 @@ struct ConstrainedLeapfrogResult { arma::vec theta; ///< Updated position arma::vec r; ///< Updated momentum bool reversible; ///< Whether the forward-backward check passed - double max_diff; ///< Max-norm of forward-backward position difference }; @@ -222,7 +197,7 @@ struct ConstrainedLeapfrogResult { * @param inv_mass_diag Diagonal of the inverse mass matrix * @param project_position SHAKE position projection callback * @param project_momentum RATTLE momentum projection callback - * @param reverse_check_factor Factor for eps^2-scaled tolerance + * @param reverse_check_tol Factor for eps^2-scaled tolerance * @return ConstrainedLeapfrogResult with position, momentum, and reversibility flag */ ConstrainedLeapfrogResult leapfrog_constrained_checked( @@ -233,7 +208,7 @@ ConstrainedLeapfrogResult leapfrog_constrained_checked( const arma::vec& inv_mass_diag, const ProjectPositionFn& project_position, const ProjectMomentumFn& project_momentum, - double reverse_check_factor = kReverseCheckFactor + double reverse_check_tol ); diff --git a/src/mcmc/algorithms/nuts.h b/src/mcmc/algorithms/nuts.h index ac7e0a42..be89f21b 100644 --- a/src/mcmc/algorithms/nuts.h +++ b/src/mcmc/algorithms/nuts.h @@ -32,7 +32,6 @@ struct BuildTreeResult { int n_alpha; ///< Number of proposals contributing to alpha bool divergent; ///< Whether this subtree diverged bool non_reversible; ///< Whether a non-reversible step was detected - double max_rev_diff; ///< Worst-case reversibility max_diff in this subtree }; @@ -69,5 +68,5 @@ StepResult nuts_step( const ProjectPositionFn* project_position = nullptr, const ProjectMomentumFn* project_momentum = nullptr, bool reverse_check = true, - double reverse_check_tol = kReverseCheckFactor + double reverse_check_tol = 0.5 ); diff --git a/src/mcmc/execution/step_result.h b/src/mcmc/execution/step_result.h index 490ed289..6e3b986e 100644 --- a/src/mcmc/execution/step_result.h +++ b/src/mcmc/execution/step_result.h @@ -29,7 +29,6 @@ struct NUTSDiagnostics : public DiagnosticsBase { bool divergent; ///< Whether a divergence occurred bool non_reversible; ///< Whether a non-reversible constrained step occurred double energy; ///< Final Hamiltonian (-log posterior + kinetic energy) - double max_rev_diff; ///< Worst-case reversibility max_diff across the tree }; From c616507ab5f839de6c5d4f7b5e55e593d4acde71 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:21:24 +0200 Subject: [PATCH 5/9] refactor: remove max_rev_diff and max_diff propagation Remove max_rev_diff from BuildTreeResult, NUTSDiagnostics, and all NUTS tree-building code. Remove max_diff from ConstrainedLeapfrogResult. The binary non_reversible flag is sufficient for diagnostics. --- src/mcmc/algorithms/nuts.cpp | 15 ++------------- src/mcmc/samplers/sampler_base.h | 2 +- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/mcmc/algorithms/nuts.cpp b/src/mcmc/algorithms/nuts.cpp index 6e68c7c8..ea51786a 100644 --- a/src/mcmc/algorithms/nuts.cpp +++ b/src/mcmc/algorithms/nuts.cpp @@ -75,9 +75,8 @@ BuildTreeResult build_tree( // Base case: take a single leapfrog step arma::vec theta_new, r_new; bool non_reversible = false; - double step_max_diff = 0.0; if (project_position && project_momentum) { - // Always run the checked variant so we can observe max_diff. + // Always run the checked variant so we can observe reversibility. // The reverse_check flag controls whether we ACT on the result // (i.e. terminate the tree). Observation is always on. auto checked = leapfrog_constrained_checked( @@ -87,7 +86,6 @@ BuildTreeResult build_tree( ); theta_new = std::move(checked.theta); r_new = std::move(checked.r); - step_max_diff = checked.max_diff; non_reversible = !checked.reversible; } else { std::tie(theta_new, r_new) = leapfrog_memo( @@ -96,7 +94,7 @@ BuildTreeResult build_tree( } // Non-reversible step terminates the tree only when reverse_check is on. - // When off (or during warmup observation mode), we record but don't act. + // During warmup, we record but don't act. if (reverse_check && non_reversible) { arma::vec p_sharp = inv_mass_diag % r_new; BuildTreeResult result; @@ -117,7 +115,6 @@ BuildTreeResult build_tree( result.n_alpha = 1; result.divergent = false; result.non_reversible = true; - result.max_rev_diff = step_max_diff; return result; } @@ -149,7 +146,6 @@ BuildTreeResult build_tree( result.n_alpha = 1; result.divergent = divergent; result.non_reversible = non_reversible; // record even when not acting - result.max_rev_diff = step_max_diff; return result; } else { @@ -167,7 +163,6 @@ BuildTreeResult build_tree( bool divergent = init_result.divergent; bool non_reversible = init_result.non_reversible; - double max_rev_diff = init_result.max_rev_diff; // Extract values from init subtree (move — init_result not used again) arma::vec theta_min = std::move(init_result.theta_min); @@ -228,7 +223,6 @@ BuildTreeResult build_tree( result.n_alpha = n_alpha_prime + final_result.n_alpha; result.divergent = divergent || final_result.divergent; result.non_reversible = non_reversible || final_result.non_reversible; - result.max_rev_diff = std::max(max_rev_diff, final_result.max_rev_diff); return result; } @@ -243,7 +237,6 @@ BuildTreeResult build_tree( int n_alpha_double_prime = final_result.n_alpha; divergent = divergent || final_result.divergent; non_reversible = non_reversible || final_result.non_reversible; - max_rev_diff = std::max(max_rev_diff, final_result.max_rev_diff); // Multinomial sampling from the combined subtree double denom = static_cast(n_prime + n_double_prime); @@ -296,7 +289,6 @@ BuildTreeResult build_tree( result.n_alpha = n_alpha_prime; result.divergent = divergent; result.non_reversible = non_reversible; - result.max_rev_diff = max_rev_diff; return result; } } @@ -318,7 +310,6 @@ StepResult nuts_step( Memoizer memo(joint); bool any_divergence = false; bool any_non_reversible = false; - double worst_rev_diff = 0.0; arma::vec r0 = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, init_theta.n_elem); @@ -385,7 +376,6 @@ StepResult nuts_step( any_divergence = any_divergence || result.divergent; any_non_reversible = any_non_reversible || result.non_reversible; - worst_rev_diff = std::max(worst_rev_diff, result.max_rev_diff); alpha = result.alpha; n_alpha = result.n_alpha; @@ -425,7 +415,6 @@ StepResult nuts_step( diag->divergent = any_divergence; diag->non_reversible = any_non_reversible; diag->energy = energy; - diag->max_rev_diff = worst_rev_diff; return {theta, accept_prob, diag}; } diff --git a/src/mcmc/samplers/sampler_base.h b/src/mcmc/samplers/sampler_base.h index 05f75ac0..09c0f9c3 100644 --- a/src/mcmc/samplers/sampler_base.h +++ b/src/mcmc/samplers/sampler_base.h @@ -115,7 +115,7 @@ class GradientSamplerBase : public SamplerBase { step_size_ = adapt_->current_step_size(); // Phase-aware reverse check: observe during warmup, enforce during sampling. - // The check always runs (recording max_diff and non_reversible counts), + // The check always runs (recording non_reversible counts), // but only terminates trees / rejects steps when enforcing. enforce_reverse_check_ = schedule_.sampling(iteration); From 0a4f47357e08fb6746cfaa7d353824d4fecf6739 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:21:37 +0200 Subject: [PATCH 6/9] fix: add missing #include to mixed_mrf_metropolis.cpp The file uses std::move extensively but relied on transitive inclusion through RcppArmadillo.h, which is not guaranteed by the standard. --- src/models/mixed/mixed_mrf_metropolis.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models/mixed/mixed_mrf_metropolis.cpp b/src/models/mixed/mixed_mrf_metropolis.cpp index 74c9ea8a..7901bf92 100644 --- a/src/models/mixed/mixed_mrf_metropolis.cpp +++ b/src/models/mixed/mixed_mrf_metropolis.cpp @@ -1,4 +1,5 @@ #include +#include #include "models/mixed/mixed_mrf_model.h" #include "rng/rng_utils.h" #include "mcmc/execution/step_result.h" From fef4546b496f89cd9d36b724b85e75d95c6f092f Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:21:44 +0200 Subject: [PATCH 7/9] docs: add non-reversible steps section to diagnostics vignette Document what non-reversible steps mean (constrained integrator round-trip failure for MRFs with continuous variables), when they occur, and remediation: increase warmup, then switch to MH. --- vignettes/diagnostics.Rmd | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vignettes/diagnostics.Rmd b/vignettes/diagnostics.Rmd index 5b8b2726..f6136b3c 100644 --- a/vignettes/diagnostics.Rmd +++ b/vignettes/diagnostics.Rmd @@ -156,6 +156,12 @@ If you see a large number of divergences, consider increasing `target_accept` (w NUTS builds trajectories by repeatedly doubling their length until a "U-turn" criterion is satisfied. If the trajectory frequently reaches the maximum allowed depth (`nuts_max_depth`, default 10), it suggests the sampler may benefit from longer trajectories to explore the posterior efficiently. Hitting the maximum depth occasionally is normal; hitting it on most iterations may indicate challenging posterior geometry. If this happens, consider increasing `nuts_max_depth`. +## Non-reversible steps + +For MRFs with continuous variables, the leapfrog integrator enforces equality constraints through a projection step. After each forward step, the integrator checks whether reversing the step returns to the starting point. When the round-trip error exceeds a tolerance scaled by the square of the step size, the step is flagged as non-reversible. + +A small number of non-reversible steps is not a concern. A large number indicates that the step size is too large for the constraint geometry. Because the step size is tuned during warmup, the most effective remedy is to increase `warmup` so the adapter has more time to find an appropriate step size. If non-reversible steps persist after increasing warmup, switch to `update_method = "adaptive-metropolis"`. + ## Warmup and equilibration Standard HMC/NUTS warmup is designed to tune the step size and mass matrix for the continuous parameters. In models with edge selection, the discrete graph structure may take longer to reach its stationary distribution than the continuous parameters. As a result, even after warmup completes, the first portion of the sampling phase may still show transient behavior (i.e., non-stationarity). From 409fc7e00373f2148e7421472df2bbc0800ce570 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:21:51 +0200 Subject: [PATCH 8/9] =?UTF-8?q?fix:=20rename=20reverse=5Fcheck=5Ffactor=20?= =?UTF-8?q?=E2=86=92=20reverse=5Fcheck=5Ftol=20in=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The R wrapper was renamed but 7 test calls still used the old name, causing the argument to be silently ignored. Tests with extreme tolerances (1e-11, 1e-15) were passing by accident using the default. --- tests/testthat/test-reversibility-check.R | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test-reversibility-check.R b/tests/testthat/test-reversibility-check.R index a9da9b09..80affe3d 100644 --- a/tests/testthat/test-reversibility-check.R +++ b/tests/testthat/test-reversibility-check.R @@ -75,7 +75,7 @@ test_that("constrained leapfrog checked passes for well-behaved steps", { result = ggm_test_leapfrog_constrained_checked( sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale, - reverse_check_factor = 0.5 + reverse_check_tol = 0.5 ) expect_equal(result$non_reversible_count, 0L) @@ -125,7 +125,7 @@ test_that("extreme tolerance detects non-reversible steps", { # With an impossibly tight factor, expect some failures result = ggm_test_leapfrog_constrained_checked( sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale, - reverse_check_factor = 1e-11 + reverse_check_tol = 1e-11 ) expect_gt(result$non_reversible_count, 0L) @@ -172,14 +172,14 @@ test_that("large step sizes increase reversibility error", { # Small step: should be reversible small = ggm_test_leapfrog_constrained_checked( sc$x0, sc$r0, 0.005, 10, sc$S, sc$n, edges, sc$scale, - reverse_check_factor = 0.5 + reverse_check_tol = 0.5 ) expect_equal(small$non_reversible_count, 0L) # Large step (100x): near-machine-epsilon tolerance catches round-trip drift large = ggm_test_leapfrog_constrained_checked( sc$x0, sc$r0, 0.5, 20, sc$S, sc$n, edges, sc$scale, - reverse_check_factor = 1e-15 + reverse_check_tol = 1e-15 ) expect_gt(large$non_reversible_count, 0L) }) @@ -198,7 +198,7 @@ test_that("long trajectory (200 steps) stays reversible at normal tolerance", { result = ggm_test_leapfrog_constrained_checked( sc$x0, sc$r0, 0.005, 200, sc$S, sc$n, edges, sc$scale, - reverse_check_factor = 0.5 + reverse_check_tol = 0.5 ) expect_equal(result$non_reversible_count, 0L) @@ -229,7 +229,7 @@ test_that("rank-deficient regime (n < p) stays reversible", { result = ggm_test_leapfrog_constrained_checked( x0, r0, 0.005, 30, S, n, edges, 2.5, - reverse_check_factor = 0.5 + reverse_check_tol = 0.5 ) expect_equal(result$non_reversible_count, 0L) @@ -266,7 +266,7 @@ test_that("ill-conditioned data (kappa ~ 1e4) stays reversible", { result = ggm_test_leapfrog_constrained_checked( x0, r0, 0.003, 30, S, n, edges, 2.5, - reverse_check_factor = 0.5 + reverse_check_tol = 0.5 ) expect_equal(result$non_reversible_count, 0L) From 00f76f966b9e5c6a726968aa8de92dbcc869e6c7 Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:21:57 +0200 Subject: [PATCH 9/9] build: add CODE_OF_CONDUCT.md to .Rbuildignore Removes the CRAN NOTE about non-standard top-level file. --- .Rbuildignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.Rbuildignore b/.Rbuildignore index 030e91ef..af3759ab 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -17,6 +17,7 @@ # GitHub / CI ^\.github$ ^CONTRIBUTING\.md$ +^CODE_OF_CONDUCT\.md$ # Hidden config files (development only) ^\.lintr$