diff --git a/.Rbuildignore b/.Rbuildignore index 030e91ef..af3759ab 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -17,6 +17,7 @@ # GitHub / CI ^\.github$ ^CONTRIBUTING\.md$ +^CODE_OF_CONDUCT\.md$ # Hidden config files (development only) ^\.lintr$ diff --git a/R/RcppExports.R b/R/RcppExports.R index 2bd030d0..9cfbc936 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -45,6 +45,10 @@ ggm_test_leapfrog_constrained <- function(x0, r0, step_size, n_steps, suf_stat, .Call(`_bgms_ggm_test_leapfrog_constrained`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, inv_mass_in) } +ggm_test_leapfrog_constrained_checked <- function(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol = 0.5, inv_mass_in = NULL) { + .Call(`_bgms_ggm_test_leapfrog_constrained_checked`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol, inv_mass_in) +} + .compute_ess_cpp <- function(array3d) { .Call(`_bgms_compute_ess_cpp`, array3d) } diff --git a/R/bgm_spec.R b/R/bgm_spec.R index 243547f5..f33d5c78 100644 --- a/R/bgm_spec.R +++ b/R/bgm_spec.R @@ -327,20 +327,20 @@ bgm_spec = function(x, # --- Sampler (needs is_continuous and edge_selection early) ------------------ sampler = validate_sampler( - update_method = update_method, - target_accept = target_accept, - iter = iter, - warmup = warmup, + update_method = update_method, + target_accept = target_accept, + iter = iter, + warmup = warmup, hmc_num_leapfrogs = hmc_num_leapfrogs, - nuts_max_depth = nuts_max_depth, + nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - chains = chains, - cores = cores, - seed = seed, - display_progress = display_progress, - is_continuous = is_continuous, - edge_selection = if(model_type == "compare") FALSE else edge_selection, - verbose = verbose + chains = chains, + cores = cores, + seed = seed, + display_progress = display_progress, + is_continuous = is_continuous, + edge_selection = 
if(model_type == "compare") FALSE else edge_selection, + verbose = verbose ) # --- Build by model type ---------------------------------------------------- diff --git a/R/build_output.R b/R/build_output.R index df2a83f8..b4622600 100644 --- a/R/build_output.R +++ b/R/build_output.R @@ -305,6 +305,7 @@ build_output_bgm = function(spec, raw) { } if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible if(!is.null(chain$energy)) res[["energy__"]] = chain$energy res }) @@ -330,6 +331,7 @@ build_output_bgm = function(spec, raw) { } if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible if(!is.null(chain$energy)) res[["energy__"]] = chain$energy res }) @@ -593,6 +595,7 @@ build_output_mixed_mrf = function(spec, raw) { } if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible if(!is.null(chain$energy)) res[["energy__"]] = chain$energy res }) diff --git a/R/nuts_diagnostics.R b/R/nuts_diagnostics.R index 854b2d2c..fa39e2da 100644 --- a/R/nuts_diagnostics.R +++ b/R/nuts_diagnostics.R @@ -110,11 +110,12 @@ check_warmup_complete = function(energy_mat) { # Returns: An invisible named list with: # - treedepth: Integer matrix (chains x iterations). # - divergent: Integer matrix (chains x iterations), 0/1. +# - non_reversible: Integer matrix (chains x iterations), 0/1. # - energy: Numeric matrix (chains x iterations). # - ebfmi: Numeric vector of per-chain E-BFMI values. # - warmup_check: Output of check_warmup_complete(). 
# - summary: List with total_divergences, max_tree_depth_hits, -# min_ebfmi, and warmup_incomplete (logical). +# min_ebfmi, total_non_reversible, and warmup_incomplete (logical). # ------------------------------------------------------------------------------ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) { nuts_chains = Filter(function(chain) { @@ -134,6 +135,12 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) divergent_mat = combine_diag("divergent__") energy_mat = combine_diag("energy__") + non_reversible_mat = if("non_reversible__" %in% names(nuts_chains[[1]])) { + combine_diag("non_reversible__") + } else { + matrix(0L, nrow = nrow(divergent_mat), ncol = ncol(divergent_mat)) + } + # E-BFMI per chain compute_ebfmi = function(energy) { mean(diff(energy)^2) / stats::var(energy) @@ -145,6 +152,7 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) # Summaries n_total = nrow(divergent_mat) * ncol(divergent_mat) total_divergences = sum(divergent_mat) + total_non_reversible = sum(non_reversible_mat) max_tree_depth_hits = sum(treedepth_mat == nuts_max_depth) min_ebfmi = min(ebfmi_per_chain) low_ebfmi_chains = which(ebfmi_per_chain < 0.2) @@ -169,6 +177,14 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) } } + if(total_non_reversible > 0) { + non_rev_rate = total_non_reversible / n_total + issues = c(issues, sprintf( + "Non-reversible steps: %d (%.3f%%) - constrained integrator round-trip failed", + total_non_reversible, 100 * non_rev_rate + )) + } + if(max_tree_depth_hits > 0) { if(depth_hit_rate > 0.01) { issues = c(issues, sprintf( @@ -203,11 +219,13 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) invisible(list( treedepth = treedepth_mat, divergent = divergent_mat, + non_reversible = non_reversible_mat, energy = energy_mat, ebfmi = ebfmi_per_chain, warmup_check = warmup_check, summary = list( 
total_divergences = total_divergences, + total_non_reversible = total_non_reversible, max_tree_depth_hits = max_tree_depth_hits, min_ebfmi = min_ebfmi, warmup_incomplete = any(warmup_check$warmup_incomplete) diff --git a/R/validate_sampler.R b/R/validate_sampler.R index 901de223..19f41ab0 100644 --- a/R/validate_sampler.R +++ b/R/validate_sampler.R @@ -197,16 +197,16 @@ validate_sampler = function(update_method, progress_type = progress_type_from_display_progress(display_progress) list( - update_method = update_method, - target_accept = target_accept, - iter = iter, - warmup = warmup, + update_method = update_method, + target_accept = target_accept, + iter = iter, + warmup = warmup, hmc_num_leapfrogs = hmc_num_leapfrogs, - nuts_max_depth = nuts_max_depth, + nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, - chains = chains, - cores = cores, - seed = seed, - progress_type = progress_type + chains = chains, + cores = cores, + seed = seed, + progress_type = progress_type ) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index ad4e8d3c..28a6f1fa 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -191,6 +191,26 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// ggm_test_leapfrog_constrained_checked +Rcpp::List ggm_test_leapfrog_constrained_checked(const arma::vec& x0, const arma::vec& r0, double step_size, int n_steps, const arma::mat& suf_stat, int n, const arma::imat& edge_indicators, double pairwise_scale, double reverse_check_tol, Rcpp::Nullable inv_mass_in); +RcppExport SEXP _bgms_ggm_test_leapfrog_constrained_checked(SEXP x0SEXP, SEXP r0SEXP, SEXP step_sizeSEXP, SEXP n_stepsSEXP, SEXP suf_statSEXP, SEXP nSEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP reverse_check_tolSEXP, SEXP inv_mass_inSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::vec& >::type x0(x0SEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type 
r0(r0SEXP); + Rcpp::traits::input_parameter< double >::type step_size(step_sizeSEXP); + Rcpp::traits::input_parameter< int >::type n_steps(n_stepsSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type suf_stat(suf_statSEXP); + Rcpp::traits::input_parameter< int >::type n(nSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type edge_indicators(edge_indicatorsSEXP); + Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP); + Rcpp::traits::input_parameter< double >::type reverse_check_tol(reverse_check_tolSEXP); + Rcpp::traits::input_parameter< Rcpp::Nullable >::type inv_mass_in(inv_mass_inSEXP); + rcpp_result_gen = Rcpp::wrap(ggm_test_leapfrog_constrained_checked(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol, inv_mass_in)); + return rcpp_result_gen; +END_RCPP +} // compute_ess_cpp Rcpp::NumericVector compute_ess_cpp(Rcpp::NumericVector array3d); RcppExport SEXP _bgms_compute_ess_cpp(SEXP array3dSEXP) { @@ -642,6 +662,7 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_ggm_test_logp_and_gradient_full", (DL_FUNC) &_bgms_ggm_test_logp_and_gradient_full, 5}, {"_bgms_ggm_test_project_momentum", (DL_FUNC) &_bgms_ggm_test_project_momentum, 4}, {"_bgms_ggm_test_leapfrog_constrained", (DL_FUNC) &_bgms_ggm_test_leapfrog_constrained, 9}, + {"_bgms_ggm_test_leapfrog_constrained_checked", (DL_FUNC) &_bgms_ggm_test_leapfrog_constrained_checked, 10}, {"_bgms_compute_ess_cpp", (DL_FUNC) &_bgms_compute_ess_cpp, 1}, {"_bgms_compute_rhat_cpp", (DL_FUNC) &_bgms_compute_rhat_cpp, 1}, {"_bgms_compute_indicator_ess_cpp", (DL_FUNC) &_bgms_compute_indicator_ess_cpp, 1}, diff --git a/src/ggm_gradient_interface.cpp b/src/ggm_gradient_interface.cpp index f28e1fbe..7cec08bd 100644 --- a/src/ggm_gradient_interface.cpp +++ b/src/ggm_gradient_interface.cpp @@ -251,3 +251,66 @@ Rcpp::List ggm_test_leapfrog_constrained( Rcpp::Named("dH") = H_final - H0 ); } + + +// [[Rcpp::export]] +Rcpp::List 
ggm_test_leapfrog_constrained_checked( + const arma::vec& x0, + const arma::vec& r0, + double step_size, + int n_steps, + const arma::mat& suf_stat, + int n, + const arma::imat& edge_indicators, + double pairwise_scale, + double reverse_check_tol = 0.5, + Rcpp::Nullable inv_mass_in = R_NilValue) +{ + size_t p = edge_indicators.n_rows; + + arma::mat inc_prob(p, p, arma::fill::value(0.5)); + GGMModel model(static_cast(p), suf_stat, inc_prob, + edge_indicators, true, pairwise_scale); + + Memoizer::JointFn joint = [&model](const arma::vec& x) + -> std::pair { + return model.logp_and_gradient_full(x); + }; + Memoizer memo(joint); + + arma::vec inv_mass; + if(inv_mass_in.isNotNull()) { + inv_mass = Rcpp::as(inv_mass_in); + } else { + inv_mass = arma::ones(x0.n_elem); + } + + ProjectPositionFn proj_pos = [&model, &inv_mass](arma::vec& x) { + model.project_position(x, inv_mass); + }; + ProjectMomentumFn proj_mom = [&model, &inv_mass](arma::vec& r, const arma::vec& x) { + model.project_momentum(r, x, inv_mass); + }; + + arma::vec x = x0; + arma::vec r = r0; + int non_reversible_count = 0; + + for (int s = 0; s < n_steps; ++s) { + auto result = leapfrog_constrained_checked( + x, r, step_size, memo, inv_mass, proj_pos, proj_mom, + reverse_check_tol + ); + x = std::move(result.theta); + r = std::move(result.r); + if (!result.reversible) { + non_reversible_count++; + } + } + + return Rcpp::List::create( + Rcpp::Named("x") = Rcpp::wrap(x), + Rcpp::Named("r") = Rcpp::wrap(r), + Rcpp::Named("non_reversible_count") = non_reversible_count + ); +} diff --git a/src/mcmc/algorithms/hmc.cpp b/src/mcmc/algorithms/hmc.cpp index 51505f18..835308e4 100644 --- a/src/mcmc/algorithms/hmc.cpp +++ b/src/mcmc/algorithms/hmc.cpp @@ -184,7 +184,9 @@ StepResult hmc_step( const arma::vec& inv_mass_diag, const ProjectPositionFn& project_position, const ProjectMomentumFn& project_momentum, - SafeRNG& rng + SafeRNG& rng, + bool reverse_check, + double reverse_check_tol ) { Memoizer memo(joint); @@ 
-198,11 +200,31 @@ StepResult hmc_step( // Run num_leapfrogs constrained leapfrog steps arma::vec theta = init_theta; + bool non_reversible = false; for (int i = 0; i < num_leapfrogs; ++i) { - std::tie(theta, r) = leapfrog_constrained( - theta, r, step_size, memo, inv_mass_diag, - project_position, project_momentum - ); + if (reverse_check) { + auto checked = leapfrog_constrained_checked( + theta, r, step_size, memo, inv_mass_diag, + project_position, project_momentum, + reverse_check_tol + ); + theta = std::move(checked.theta); + r = std::move(checked.r); + if (!checked.reversible) { + non_reversible = true; + break; + } + } else { + std::tie(theta, r) = leapfrog_constrained( + theta, r, step_size, memo, inv_mass_diag, + project_position, project_momentum + ); + } + } + + // Non-reversible step forces rejection + if (non_reversible) { + return {init_theta, 0.0}; } double logp1 = memo.cached_log_post(theta); diff --git a/src/mcmc/algorithms/hmc.h b/src/mcmc/algorithms/hmc.h index 84eb949a..7b09c10e 100644 --- a/src/mcmc/algorithms/hmc.h +++ b/src/mcmc/algorithms/hmc.h @@ -150,6 +150,8 @@ StepResult hmc_step( * @param project_position SHAKE position projection callback * @param project_momentum RATTLE momentum projection callback * @param rng Thread-safe random number generator + * @param reverse_check Enable runtime reversibility check + * @param reverse_check_tol Factor for eps²-scaled reversibility tolerance * @return StepResult with accepted state and acceptance probability */ StepResult hmc_step( @@ -160,5 +162,7 @@ StepResult hmc_step( const arma::vec& inv_mass_diag, const ProjectPositionFn& project_position, const ProjectMomentumFn& project_momentum, - SafeRNG& rng + SafeRNG& rng, + bool reverse_check = true, + double reverse_check_tol = 0.5 ); diff --git a/src/mcmc/algorithms/leapfrog.cpp b/src/mcmc/algorithms/leapfrog.cpp index 809afcef..69ff8f3d 100644 --- a/src/mcmc/algorithms/leapfrog.cpp +++ b/src/mcmc/algorithms/leapfrog.cpp @@ -70,6 +70,40 @@ 
std::pair leapfrog_constrained( } +ConstrainedLeapfrogResult leapfrog_constrained_checked( + const arma::vec& theta, + const arma::vec& r, + double eps, + Memoizer& memo, + const arma::vec& inv_mass_diag, + const ProjectPositionFn& project_position, + const ProjectMomentumFn& project_momentum, + double reverse_check_tol +) { + // --- Forward step --- + auto [theta_new, r_new] = leapfrog_constrained( + theta, r, eps, memo, inv_mass_diag, + project_position, project_momentum + ); + + // --- Backward step (negate momentum, step forward) --- + // Use a separate Memoizer so the caller's cache stays at theta_new. + Memoizer back_memo(memo.joint_fn); + arma::vec r_back = -r_new; + auto [theta_back, r_back_out] = leapfrog_constrained( + theta_new, r_back, eps, back_memo, inv_mass_diag, + project_position, project_momentum + ); + + // --- Reversibility check (eps^2-scaled max-norm) --- + double max_diff = arma::max(arma::abs(theta_back - theta)); + double tol = reverse_check_tol * eps * eps; + bool reversible = (max_diff <= tol); + + return {std::move(theta_new), std::move(r_new), reversible}; +} + + LeapfrogJointResult leapfrog( const arma::vec& theta_init, const arma::vec& r_init, diff --git a/src/mcmc/algorithms/leapfrog.h b/src/mcmc/algorithms/leapfrog.h index 031ec383..e58f7d79 100644 --- a/src/mcmc/algorithms/leapfrog.h +++ b/src/mcmc/algorithms/leapfrog.h @@ -163,6 +163,55 @@ std::pair leapfrog_constrained( ); +// --------------------------------------------------------------------------- +// Constrained leapfrog with runtime reversibility check +// --------------------------------------------------------------------------- + +/** + * Result of a constrained leapfrog step with reversibility information. + */ +struct ConstrainedLeapfrogResult { + arma::vec theta; ///< Updated position + arma::vec r; ///< Updated momentum + bool reversible; ///< Whether the forward-backward check passed +}; + + +/** + * Constrained leapfrog step with a runtime reversibility check. 
+ * + * Performs a forward constrained leapfrog step, then a backward step + * (negate momentum, step forward). If the round-trip position differs + * from the original by more than reverse_check_tol * eps^2 in max-norm, + * the step is flagged as non-reversible. Only positions are compared. + * + * This is the runtime analogue of Mici's ConstrainedLeapfrogIntegrator + * reverse check (Zappa et al., 2018; Lelièvre et al., 2019), adapted + * to use eps^2-scaled tolerance matching the O(eps^2) column-coupling + * error of the direct SHAKE solver. + * + * @param theta Current position (parameter vector) + * @param r Current momentum vector + * @param eps Step size for integration + * @param memo Memoizer caching gradient evaluations + * @param inv_mass_diag Diagonal of the inverse mass matrix + * @param project_position SHAKE position projection callback + * @param project_momentum RATTLE momentum projection callback + * @param reverse_check_tol Factor for eps^2-scaled tolerance + * @return ConstrainedLeapfrogResult with position, momentum, and reversibility flag + */ +ConstrainedLeapfrogResult leapfrog_constrained_checked( + const arma::vec& theta, + const arma::vec& r, + double eps, + Memoizer& memo, + const arma::vec& inv_mass_diag, + const ProjectPositionFn& project_position, + const ProjectMomentumFn& project_momentum, + double reverse_check_tol +); + + /** * LeapfrogJointResult - Return type for multi-step leapfrog integration. 
* diff --git a/src/mcmc/algorithms/nuts.cpp b/src/mcmc/algorithms/nuts.cpp index df7b8bc5..ea51786a 100644 --- a/src/mcmc/algorithms/nuts.cpp +++ b/src/mcmc/algorithms/nuts.cpp @@ -65,24 +65,59 @@ BuildTreeResult build_tree( const arma::vec& inv_mass_diag, SafeRNG& rng, const ProjectPositionFn* project_position, - const ProjectMomentumFn* project_momentum + const ProjectMomentumFn* project_momentum, + bool reverse_check, + double reverse_check_tol ) { constexpr double Delta_max = 1000.0; if (j == 0) { // Base case: take a single leapfrog step arma::vec theta_new, r_new; + bool non_reversible = false; if (project_position && project_momentum) { - std::tie(theta_new, r_new) = leapfrog_constrained( + // Always run the checked variant so we can observe reversibility. + // The reverse_check flag controls whether we ACT on the result + // (i.e. terminate the tree). Observation is always on. + auto checked = leapfrog_constrained_checked( theta, r, v * step_size, memo, inv_mass_diag, - *project_position, *project_momentum + *project_position, *project_momentum, + reverse_check_tol ); + theta_new = std::move(checked.theta); + r_new = std::move(checked.r); + non_reversible = !checked.reversible; } else { std::tie(theta_new, r_new) = leapfrog_memo( theta, r, v * step_size, memo, inv_mass_diag ); } + // Non-reversible step terminates the tree only when reverse_check is on. + // During warmup, we record but don't act. 
+ if (reverse_check && non_reversible) { + arma::vec p_sharp = inv_mass_diag % r_new; + BuildTreeResult result; + result.theta_min = theta_new; + result.theta_plus = theta_new; + result.r_min = r_new; + result.r_plus = r_new; + result.rho = r_new; + result.p_beg = r_new; + result.p_end = r_new; + result.r_prime = std::move(r_new); + result.theta_prime = std::move(theta_new); + result.p_sharp_beg = p_sharp; + result.p_sharp_end = std::move(p_sharp); + result.n_prime = 0; + result.s_prime = 0; + result.alpha = 0.0; + result.n_alpha = 1; + result.divergent = false; + result.non_reversible = true; + return result; + } + auto logp = memo.cached_log_post(theta_new); double kin = kinetic_energy(r_new, inv_mass_diag); int n_new = 1 * (log_u <= logp - kin); @@ -110,13 +145,15 @@ BuildTreeResult build_tree( result.alpha = alpha; result.n_alpha = 1; result.divergent = divergent; + result.non_reversible = non_reversible; // record even when not acting return result; } else { // Recursion: build the first subtree BuildTreeResult init_result = build_tree( theta, r, log_u, v, j - 1, step_size, theta_0, r0, logp0, kin0, memo, - inv_mass_diag, rng, project_position, project_momentum + inv_mass_diag, rng, project_position, project_momentum, reverse_check, + reverse_check_tol ); if (init_result.s_prime == 0) { @@ -125,6 +162,7 @@ BuildTreeResult build_tree( } bool divergent = init_result.divergent; + bool non_reversible = init_result.non_reversible; // Extract values from init subtree (move — init_result not used again) arma::vec theta_min = std::move(init_result.theta_min); @@ -147,7 +185,8 @@ BuildTreeResult build_tree( if (v == -1) { final_result = build_tree( theta_min, r_min, log_u, v, j - 1, step_size, theta_0, r0, logp0, - kin0, memo, inv_mass_diag, rng, project_position, project_momentum + kin0, memo, inv_mass_diag, rng, project_position, project_momentum, + reverse_check, reverse_check_tol ); // Update backward boundary theta_min = std::move(final_result.theta_min); @@ 
-155,7 +194,8 @@ BuildTreeResult build_tree( } else { final_result = build_tree( theta_plus, r_plus, log_u, v, j - 1, step_size, theta_0, r0, logp0, - kin0, memo, inv_mass_diag, rng, project_position, project_momentum + kin0, memo, inv_mass_diag, rng, project_position, project_momentum, + reverse_check, reverse_check_tol ); // Update forward boundary theta_plus = std::move(final_result.theta_plus); @@ -182,6 +222,7 @@ BuildTreeResult build_tree( result.alpha = alpha_prime + final_result.alpha; result.n_alpha = n_alpha_prime + final_result.n_alpha; result.divergent = divergent || final_result.divergent; + result.non_reversible = non_reversible || final_result.non_reversible; return result; } @@ -195,6 +236,7 @@ BuildTreeResult build_tree( double alpha_double_prime = final_result.alpha; int n_alpha_double_prime = final_result.n_alpha; divergent = divergent || final_result.divergent; + non_reversible = non_reversible || final_result.non_reversible; // Multinomial sampling from the combined subtree double denom = static_cast(n_prime + n_double_prime); @@ -246,6 +288,7 @@ BuildTreeResult build_tree( result.alpha = alpha_prime; result.n_alpha = n_alpha_prime; result.divergent = divergent; + result.non_reversible = non_reversible; return result; } } @@ -259,11 +302,14 @@ StepResult nuts_step( SafeRNG& rng, int max_depth, const ProjectPositionFn* project_position, - const ProjectMomentumFn* project_momentum + const ProjectMomentumFn* project_momentum, + bool reverse_check, + double reverse_check_tol ) { // Create Memoizer with joint function Memoizer memo(joint); bool any_divergence = false; + bool any_non_reversible = false; arma::vec r0 = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, init_theta.n_elem); @@ -304,7 +350,8 @@ StepResult nuts_step( rho_fwd = rho; result = build_tree( theta_min, r_min, log_u, v, j, step_size, init_theta, r0, logp0, kin0, memo, - inv_mass_diag, rng, project_position, project_momentum + inv_mass_diag, rng, project_position, 
project_momentum, reverse_check, + reverse_check_tol ); theta_min = std::move(result.theta_min); r_min = std::move(result.r_min); @@ -316,7 +363,8 @@ StepResult nuts_step( rho_bck = rho; result = build_tree( theta_plus, r_plus, log_u, v, j, step_size, init_theta, r0, logp0, kin0, memo, - inv_mass_diag, rng, project_position, project_momentum + inv_mass_diag, rng, project_position, project_momentum, reverse_check, + reverse_check_tol ); theta_plus = std::move(result.theta_plus); r_plus = std::move(result.r_plus); @@ -327,6 +375,7 @@ StepResult nuts_step( } any_divergence = any_divergence || result.divergent; + any_non_reversible = any_non_reversible || result.non_reversible; alpha = result.alpha; n_alpha = result.n_alpha; @@ -364,6 +413,7 @@ StepResult nuts_step( auto diag = std::make_shared(); diag->tree_depth = j; diag->divergent = any_divergence; + diag->non_reversible = any_non_reversible; diag->energy = energy; return {theta, accept_prob, diag}; diff --git a/src/mcmc/algorithms/nuts.h b/src/mcmc/algorithms/nuts.h index 7f78b74a..be89f21b 100644 --- a/src/mcmc/algorithms/nuts.h +++ b/src/mcmc/algorithms/nuts.h @@ -31,6 +31,7 @@ struct BuildTreeResult { double alpha; ///< Sum of acceptance probabilities in the subtree int n_alpha; ///< Number of proposals contributing to alpha bool divergent; ///< Whether this subtree diverged + bool non_reversible; ///< Whether a non-reversible step was detected }; @@ -53,6 +54,8 @@ struct BuildTreeResult { * @param max_depth Maximum tree depth (default = 10) * @param project_position SHAKE position projection (nullptr for unconstrained) * @param project_momentum RATTLE momentum projection (nullptr for unconstrained) + * @param reverse_check Enable runtime reversibility check (constrained only) + * @param reverse_check_tol Factor for eps²-scaled reversibility tolerance * @return StepResult with position, acceptance probability, and NUTS diagnostics */ StepResult nuts_step( @@ -63,5 +66,7 @@ StepResult nuts_step( SafeRNG& rng, 
int max_depth = 10, const ProjectPositionFn* project_position = nullptr, - const ProjectMomentumFn* project_momentum = nullptr + const ProjectMomentumFn* project_momentum = nullptr, + bool reverse_check = true, + double reverse_check_tol = 0.5 ); diff --git a/src/mcmc/execution/chain_result.h b/src/mcmc/execution/chain_result.h index ae1b1e3c..d03dd5ed 100644 --- a/src/mcmc/execution/chain_result.h +++ b/src/mcmc/execution/chain_result.h @@ -41,6 +41,8 @@ class ChainResult { arma::ivec treedepth_samples; /// NUTS/HMC divergent transition flags (n_iter). arma::ivec divergent_samples; + /// NUTS/HMC non-reversible step flags (n_iter). + arma::ivec non_reversible_samples; /// NUTS/HMC energy diagnostic (n_iter). arma::vec energy_samples; /// Whether NUTS/HMC diagnostics are stored. @@ -82,6 +84,7 @@ class ChainResult { void reserve_nuts_diagnostics(const size_t n_iter) { treedepth_samples.set_size(n_iter); divergent_samples.set_size(n_iter); + non_reversible_samples.set_size(n_iter); energy_samples.set_size(n_iter); has_nuts_diagnostics = true; } @@ -120,9 +123,10 @@ class ChainResult { * @param divergent Whether a divergence occurred * @param energy Final Hamiltonian energy */ - void store_nuts_diagnostics(const size_t iter, int tree_depth, bool divergent, double energy) { + void store_nuts_diagnostics(const size_t iter, int tree_depth, bool divergent, bool non_reversible, double energy) { treedepth_samples(iter) = tree_depth; divergent_samples(iter) = divergent ? 1 : 0; + non_reversible_samples(iter) = non_reversible ? 
1 : 0; energy_samples(iter) = energy; } }; diff --git a/src/mcmc/execution/chain_runner.cpp b/src/mcmc/execution/chain_runner.cpp index 442f5abc..1aef3f4a 100644 --- a/src/mcmc/execution/chain_runner.cpp +++ b/src/mcmc/execution/chain_runner.cpp @@ -87,7 +87,7 @@ void run_mcmc_chain( if (chain_result.has_nuts_diagnostics && sampler->has_nuts_diagnostics()) { auto* diag = dynamic_cast(result.diagnostics.get()); if (diag) { - chain_result.store_nuts_diagnostics(sample_index, diag->tree_depth, diag->divergent, diag->energy); + chain_result.store_nuts_diagnostics(sample_index, diag->tree_depth, diag->divergent, diag->non_reversible, diag->energy); } } @@ -219,6 +219,7 @@ Rcpp::List convert_results_to_list(const std::vector& results) { if (chain.has_nuts_diagnostics) { chain_list["treedepth"] = chain.treedepth_samples; chain_list["divergent"] = chain.divergent_samples; + chain_list["non_reversible"] = chain.non_reversible_samples; chain_list["energy"] = chain.energy_samples; } } diff --git a/src/mcmc/execution/sampler_config.h b/src/mcmc/execution/sampler_config.h index 595c819a..2305a194 100644 --- a/src/mcmc/execution/sampler_config.h +++ b/src/mcmc/execution/sampler_config.h @@ -35,6 +35,12 @@ struct SamplerConfig { /// Enable missing-data imputation during sampling. bool na_impute = false; + /// Enable runtime reversibility check for constrained integration. + bool reverse_check = true; + + /// Factor for the eps²-scaled reversibility tolerance (tol = factor * eps²). + double reverse_check_tol = 0.5; + /// Random seed. 
int seed = 42; diff --git a/src/mcmc/execution/step_result.h b/src/mcmc/execution/step_result.h index 4768769f..6e3b986e 100644 --- a/src/mcmc/execution/step_result.h +++ b/src/mcmc/execution/step_result.h @@ -25,9 +25,10 @@ struct DiagnosticsBase { * NUTSDiagnostics - Per-iteration NUTS diagnostics (derives from DiagnosticsBase) */ struct NUTSDiagnostics : public DiagnosticsBase { - int tree_depth; ///< Depth of the trajectory tree - bool divergent; ///< Whether a divergence occurred - double energy; ///< Final Hamiltonian (-log posterior + kinetic energy) + int tree_depth; ///< Depth of the trajectory tree + bool divergent; ///< Whether a divergence occurred + bool non_reversible; ///< Whether a non-reversible constrained step occurred + double energy; ///< Final Hamiltonian (-log posterior + kinetic energy) }; diff --git a/src/mcmc/samplers/hmc_sampler.h b/src/mcmc/samplers/hmc_sampler.h index 1a6d7fb4..a4812bba 100644 --- a/src/mcmc/samplers/hmc_sampler.h +++ b/src/mcmc/samplers/hmc_sampler.h @@ -19,7 +19,9 @@ class HMCSampler : public GradientSamplerBase { public: explicit HMCSampler(const SamplerConfig& config, WarmupSchedule& schedule) : GradientSamplerBase(config.initial_step_size, config.target_acceptance, schedule), - num_leapfrogs_(config.num_leapfrogs) + num_leapfrogs_(config.num_leapfrogs), + reverse_check_(config.reverse_check), + reverse_check_tol_(config.reverse_check_tol) {} protected: @@ -74,11 +76,15 @@ class HMCSampler : public GradientSamplerBase { StepResult result = hmc_step( x, step_size_, joint_fn, num_leapfrogs_, inv_mass, - proj_pos, proj_mom, rng); + proj_pos, proj_mom, rng, + reverse_check_ && enforce_reverse_check_, + reverse_check_tol_); model.set_full_position(result.state); return result; } int num_leapfrogs_; + bool reverse_check_; + double reverse_check_tol_; }; diff --git a/src/mcmc/samplers/nuts_sampler.h b/src/mcmc/samplers/nuts_sampler.h index f73c7e97..4be8e2aa 100644 --- a/src/mcmc/samplers/nuts_sampler.h +++ 
b/src/mcmc/samplers/nuts_sampler.h @@ -19,7 +19,9 @@ class NUTSSampler : public GradientSamplerBase { public: explicit NUTSSampler(const SamplerConfig& config, WarmupSchedule& schedule) : GradientSamplerBase(config.initial_step_size, config.target_acceptance, schedule), - max_tree_depth_(config.max_tree_depth) + max_tree_depth_(config.max_tree_depth), + reverse_check_(config.reverse_check), + reverse_check_tol_(config.reverse_check_tol) {} bool has_nuts_diagnostics() const override { return true; } @@ -76,7 +78,9 @@ class NUTSSampler : public GradientSamplerBase { StepResult result = nuts_step( x, step_size_, joint_fn, inv_mass, rng, max_tree_depth_, - &proj_pos, &proj_mom + &proj_pos, &proj_mom, + reverse_check_ && enforce_reverse_check_, + reverse_check_tol_ ); model.set_full_position(result.state); @@ -84,4 +88,6 @@ class NUTSSampler : public GradientSamplerBase { } int max_tree_depth_; + bool reverse_check_; + double reverse_check_tol_; }; diff --git a/src/mcmc/samplers/sampler_base.h b/src/mcmc/samplers/sampler_base.h index efe5a03e..09c0f9c3 100644 --- a/src/mcmc/samplers/sampler_base.h +++ b/src/mcmc/samplers/sampler_base.h @@ -114,6 +114,11 @@ class GradientSamplerBase : public SamplerBase { // Use adaptation controller's current step size for this iteration step_size_ = adapt_->current_step_size(); + // Phase-aware reverse check: observe during warmup, enforce during sampling. + // The check always runs (recording non_reversible counts), + // but only terminates trees / rejects steps when enforcing. + enforce_reverse_check_ = schedule_.sampling(iteration); + StepResult result = do_gradient_step(model); // Let the adaptation controller handle step-size and mass-matrix logic. @@ -186,6 +191,10 @@ class GradientSamplerBase : public SamplerBase { double step_size_; double target_acceptance_; + /// Whether the reverse check should enforce (reject) this iteration. + /// Set in step() before do_gradient_step(). Derived classes read this. 
+ bool enforce_reverse_check_ = false; + public: /** * Initialize the adaptation controller and run the step-size heuristic. diff --git a/src/models/mixed/mixed_mrf_metropolis.cpp b/src/models/mixed/mixed_mrf_metropolis.cpp index 74c9ea8a..7901bf92 100644 --- a/src/models/mixed/mixed_mrf_metropolis.cpp +++ b/src/models/mixed/mixed_mrf_metropolis.cpp @@ -1,4 +1,5 @@ #include +#include #include "models/mixed/mixed_mrf_model.h" #include "rng/rng_utils.h" #include "mcmc/execution/step_result.h" diff --git a/tests/testthat/test-fit-object-contract.R b/tests/testthat/test-fit-object-contract.R index 1c72d16e..fa6ced14 100644 --- a/tests/testthat/test-fit-object-contract.R +++ b/tests/testthat/test-fit-object-contract.R @@ -319,7 +319,8 @@ for(spec in get_bgms_fixtures()) { # (This is safe because the fixture cache returns the same object.) # Instead, just verify that accessing a summary field returns data. summary_fields = grep( - "^posterior_summary_", names(fit), value = TRUE + "^posterior_summary_", names(fit), + value = TRUE ) for(field in summary_fields) { val = fit[[field]] @@ -345,7 +346,8 @@ for(spec in get_bgmcompare_fixtures()) { { fit = spec$get_fit() summary_fields = grep( - "^posterior_summary_", names(fit), value = TRUE + "^posterior_summary_", names(fit), + value = TRUE ) for(field in summary_fields) { val = fit[[field]] diff --git a/tests/testthat/test-reversibility-check.R b/tests/testthat/test-reversibility-check.R new file mode 100644 index 00000000..80affe3d --- /dev/null +++ b/tests/testthat/test-reversibility-check.R @@ -0,0 +1,301 @@ +# --------------------------------------------------------------------------- # +# Reversibility check — runtime forward-backward round-trip for RATTLE. +# +# Tests verify: +# 1. Normal steps pass the reversibility check (reversible = TRUE) +# 2. A strict tolerance triggers non-reversible detections +# 3. The non_reversible diagnostic flows through to bgm() output +# 4. 
Stress: large step sizes amplify reversibility error +# 5. Stress: long trajectories (200 steps) remain reversible +# 6. Stress: small n / large p rank-deficient regime +# 7. Stress: ill-conditioned data (high condition number) +# 8. Stress: high-dimensional GGM (p=20) integration test +# --------------------------------------------------------------------------- # + + +# ---- Helpers (reused from test-rattle-leapfrog.R) --------------------------- + +make_test_phi_rc = function(p, seed = 42) { + set.seed(seed) + A = matrix(rnorm(p * p), p, p) + K = A %*% t(A) + diag(p) + Phi = chol(K) + list(Phi = Phi, K = K, p = p) +} + +phi_to_full_position_rc = function(Phi) { + p = nrow(Phi) + x = numeric(p * (p + 1) / 2) + idx = 1 + for(q in seq_len(p)) { + if(q > 1) { + for(i in seq_len(q - 1)) { + x[idx] = Phi[i, q] + idx = idx + 1 + } + } + x[idx] = log(Phi[q, q]) + idx = idx + 1 + } + x +} + +make_scenario_rc = function(p, edges, seed) { + dat = make_test_phi_rc(p, seed = seed) + n = 10 + set.seed(seed + 1000) + X = matrix(rnorm(n * p), n, p) + S = t(X) %*% X + scale = 2.5 + + x_raw = phi_to_full_position_rc(dat$Phi) + proj = ggm_test_project_position(x_raw, edges) + x0 = as.vector(proj$x_projected) + + set.seed(seed + 2000) + r_raw = rnorm(length(x0)) + r0 = as.vector(ggm_test_project_momentum(r_raw, x0, edges)) + + list(x0 = x0, r0 = r0, S = S, n = n, scale = scale, p = p, edges = edges) +} + + +# ---- 1. 
Normal steps pass the reversibility check -------------------------- + +test_that("constrained leapfrog checked passes for well-behaved steps", { + p = 4 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 3] = 0L + edges[3, 1] = 0L + + sc = make_scenario_rc(p, edges, seed = 300) + eps = 0.005 + n_steps = 10 + + result = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale, + reverse_check_tol = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +test_that("checked leapfrog matches unchecked positions", { + p = 4 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 3] = 0L + edges[3, 1] = 0L + edges[2, 4] = 0L + edges[4, 2] = 0L + + sc = make_scenario_rc(p, edges, seed = 301) + eps = 0.005 + n_steps = 8 + + checked = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale + ) + unchecked = ggm_test_leapfrog_constrained( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale + ) + + expect_equal(as.vector(checked$x), as.vector(unchecked$x), tolerance = 1e-12) + expect_equal(as.vector(checked$r), as.vector(unchecked$r), tolerance = 1e-12) +}) + + +# ---- 2. Strict tolerance triggers non-reversible detections ---------------- + +test_that("extreme tolerance detects non-reversible steps", { + p = 4 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 3] = 0L + edges[3, 1] = 0L + edges[2, 4] = 0L + edges[4, 2] = 0L + + sc = make_scenario_rc(p, edges, seed = 302) + eps = 0.01 + n_steps = 20 + + # With an impossibly tight factor, expect some failures + result = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, eps, n_steps, sc$S, sc$n, edges, sc$scale, + reverse_check_tol = 1e-11 + ) + + expect_gt(result$non_reversible_count, 0L) +}) + + +# ---- 3. 
Integration: non_reversible diagnostic in bgm() output ------------- + +test_that("bgm() nuts_diag includes non_reversible field", { + skip_on_cran() + + set.seed(1) + p = 4 + n = 80 + data = matrix(rnorm(n * p), n, p) + colnames(data) = paste0("V", seq_len(p)) + + fit = bgm( + data, + variable_type = "continuous", + iter = 10, + warmup = 50, + edge_selection = TRUE, + chains = 1, + display_progress = "none" + ) + + if(!is.null(fit$nuts_diag)) { + expect_true("non_reversible" %in% names(fit$nuts_diag)) + expect_true("total_non_reversible" %in% names(fit$nuts_diag$summary)) + } +}) + + +# ---- 4. Stress: large step sizes ------------------------------------------- + +test_that("large step sizes increase reversibility error", { + p = 5 + edges = matrix(1L, p, p) + diag(edges) = 0L + + sc = make_scenario_rc(p, edges, seed = 400) + + # Small step: should be reversible + small = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, 0.005, 10, sc$S, sc$n, edges, sc$scale, + reverse_check_tol = 0.5 + ) + expect_equal(small$non_reversible_count, 0L) + + # Large step (100x): near-machine-epsilon tolerance catches round-trip drift + large = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, 0.5, 20, sc$S, sc$n, edges, sc$scale, + reverse_check_tol = 1e-15 + ) + expect_gt(large$non_reversible_count, 0L) +}) + + +# ---- 5. Stress: long trajectories (200 steps) ------------------------------ + +test_that("long trajectory (200 steps) stays reversible at normal tolerance", { + p = 5 + edges = matrix(1L, p, p) + diag(edges) = 0L + edges[1, 4] = 0L + edges[4, 1] = 0L + + sc = make_scenario_rc(p, edges, seed = 500) + + result = ggm_test_leapfrog_constrained_checked( + sc$x0, sc$r0, 0.005, 200, sc$S, sc$n, edges, sc$scale, + reverse_check_tol = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +# ---- 6. 
Stress: small n / large p (rank-deficient S) ----------------------- + +test_that("rank-deficient regime (n < p) stays reversible", { + p = 8 + edges = matrix(1L, p, p) + diag(edges) = 0L + + # Build scenario with n < p (rank-deficient sufficient statistic) + dat = make_test_phi_rc(p, seed = 600) + n = 5 + set.seed(601) + X = matrix(rnorm(n * p), n, p) + S = t(X) %*% X # rank = min(n, p) = 5 < 8 + + x_raw = phi_to_full_position_rc(dat$Phi) + proj = ggm_test_project_position(x_raw, edges) + x0 = as.vector(proj$x_projected) + + set.seed(602) + r_raw = rnorm(length(x0)) + r0 = as.vector(ggm_test_project_momentum(r_raw, x0, edges)) + + result = ggm_test_leapfrog_constrained_checked( + x0, r0, 0.005, 30, S, n, edges, 2.5, + reverse_check_tol = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +# ---- 7. Stress: ill-conditioned data (high condition number) ---------------- + +test_that("ill-conditioned data (kappa ~ 1e4) stays reversible", { + p = 6 + edges = matrix(1L, p, p) + diag(edges) = 0L + + # Build ill-conditioned sufficient statistic + dat = make_test_phi_rc(p, seed = 700) + n = 50 + set.seed(701) + X = matrix(rnorm(n * p), n, p) + # Stretch first column to create high condition number + X[, 1] = X[, 1] * 100 + S = t(X) %*% X + # Verify ill-conditioning + ev = eigen(S, symmetric = TRUE, only.values = TRUE)$values + kappa = max(ev) / min(ev) + expect_gt(kappa, 1e3) + + x_raw = phi_to_full_position_rc(dat$Phi) + proj = ggm_test_project_position(x_raw, edges) + x0 = as.vector(proj$x_projected) + + set.seed(702) + r_raw = rnorm(length(x0)) + r0 = as.vector(ggm_test_project_momentum(r_raw, x0, edges)) + + result = ggm_test_leapfrog_constrained_checked( + x0, r0, 0.003, 30, S, n, edges, 2.5, + reverse_check_tol = 0.5 + ) + + expect_equal(result$non_reversible_count, 0L) +}) + + +# ---- 8. 
Stress: high-dimensional GGM (p=20) integration -------------------- + +test_that("high-dimensional GGM (p=20) completes without crash", { + skip_on_cran() + + set.seed(800) + p = 20 + n = 100 + data = matrix(rnorm(n * p), n, p) + colnames(data) = paste0("V", seq_len(p)) + + fit = bgm( + data, + variable_type = "continuous", + iter = 10, + warmup = 50, + edge_selection = TRUE, + chains = 1, + display_progress = "none" + ) + + expect_s3_class(fit, "bgms") + if(!is.null(fit$nuts_diag)) { + expect_true("non_reversible" %in% names(fit$nuts_diag)) + } +}) diff --git a/vignettes/diagnostics.Rmd b/vignettes/diagnostics.Rmd index 5b8b2726..f6136b3c 100644 --- a/vignettes/diagnostics.Rmd +++ b/vignettes/diagnostics.Rmd @@ -156,6 +156,12 @@ If you see a large number of divergences, consider increasing `target_accept` (w NUTS builds trajectories by repeatedly doubling their length until a "U-turn" criterion is satisfied. If the trajectory frequently reaches the maximum allowed depth (`nuts_max_depth`, default 10), it suggests the sampler may benefit from longer trajectories to explore the posterior efficiently. Hitting the maximum depth occasionally is normal; hitting it on most iterations may indicate challenging posterior geometry. If this happens, consider increasing `nuts_max_depth`. +## Non-reversible steps + +For MRFs with continuous variables, the leapfrog integrator enforces equality constraints through a projection step. After each forward step, the integrator checks whether reversing the step returns to the starting point. When the round-trip error exceeds a tolerance scaled by the square of the step size, the step is flagged as non-reversible. The per-iteration flags are returned in `fit$nuts_diag$non_reversible`, and their total is reported as the `total_non_reversible` entry of `fit$nuts_diag$summary`. + +A small number of non-reversible steps is not a concern. A large number indicates that the step size is too large for the constraint geometry. Because the step size is tuned during warmup, the most effective remedy is to increase `warmup` so the adapter has more time to find an appropriate step size. 
If non-reversible steps persist after increasing warmup, switch to `update_method = "adaptive-metropolis"`. + ## Warmup and equilibration Standard HMC/NUTS warmup is designed to tune the step size and mass matrix for the continuous parameters. In models with edge selection, the discrete graph structure may take longer to reach its stationary distribution than the continuous parameters. As a result, even after warmup completes, the first portion of the sampling phase may still show transient behavior (i.e., non-stationarity).