Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ ggm_test_leapfrog_constrained <- function(x0, r0, step_size, n_steps, suf_stat,
.Call(`_bgms_ggm_test_leapfrog_constrained`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, inv_mass_in)
}

ggm_test_leapfrog_constrained_checked <- function(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor = 0.5, inv_mass_in = NULL) {
.Call(`_bgms_ggm_test_leapfrog_constrained_checked`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor, inv_mass_in)
}

.compute_ess_cpp <- function(array3d) {
.Call(`_bgms_compute_ess_cpp`, array3d)
}
Expand Down
24 changes: 12 additions & 12 deletions R/bgm_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -327,20 +327,20 @@ bgm_spec = function(x,

# --- Sampler (needs is_continuous and edge_selection early) ------------------
sampler = validate_sampler(
update_method = update_method,
target_accept = target_accept,
iter = iter,
warmup = warmup,
update_method = update_method,
target_accept = target_accept,
iter = iter,
warmup = warmup,
hmc_num_leapfrogs = hmc_num_leapfrogs,
nuts_max_depth = nuts_max_depth,
nuts_max_depth = nuts_max_depth,
learn_mass_matrix = learn_mass_matrix,
chains = chains,
cores = cores,
seed = seed,
display_progress = display_progress,
is_continuous = is_continuous,
edge_selection = if(model_type == "compare") FALSE else edge_selection,
verbose = verbose
chains = chains,
cores = cores,
seed = seed,
display_progress = display_progress,
is_continuous = is_continuous,
edge_selection = if(model_type == "compare") FALSE else edge_selection,
verbose = verbose
)

# --- Build by model type ----------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions R/build_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ build_output_bgm = function(spec, raw) {
}
if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth
if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent
if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible
if(!is.null(chain$energy)) res[["energy__"]] = chain$energy
res
})
Expand All @@ -330,6 +331,7 @@ build_output_bgm = function(spec, raw) {
}
if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth
if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent
if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible
if(!is.null(chain$energy)) res[["energy__"]] = chain$energy
res
})
Expand Down Expand Up @@ -593,6 +595,7 @@ build_output_mixed_mrf = function(spec, raw) {
}
if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth
if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent
if(!is.null(chain$non_reversible)) res[["non_reversible__"]] = chain$non_reversible
if(!is.null(chain$energy)) res[["energy__"]] = chain$energy
res
})
Expand Down
20 changes: 19 additions & 1 deletion R/nuts_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ check_warmup_complete = function(energy_mat) {
# Returns: An invisible named list with:
# - treedepth: Integer matrix (chains x iterations).
# - divergent: Integer matrix (chains x iterations), 0/1.
# - non_reversible: Integer matrix (chains x iterations), 0/1.
# - energy: Numeric matrix (chains x iterations).
# - ebfmi: Numeric vector of per-chain E-BFMI values.
# - warmup_check: Output of check_warmup_complete().
# - summary: List with total_divergences, max_tree_depth_hits,
# min_ebfmi, and warmup_incomplete (logical).
# min_ebfmi, total_non_reversible, and warmup_incomplete (logical).
# ------------------------------------------------------------------------------
summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE) {
nuts_chains = Filter(function(chain) {
Expand All @@ -134,6 +135,12 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE)
divergent_mat = combine_diag("divergent__")
energy_mat = combine_diag("energy__")

non_reversible_mat = if("non_reversible__" %in% names(nuts_chains[[1]])) {
combine_diag("non_reversible__")
} else {
matrix(0L, nrow = nrow(divergent_mat), ncol = ncol(divergent_mat))
}

# E-BFMI per chain
compute_ebfmi = function(energy) {
mean(diff(energy)^2) / stats::var(energy)
Expand All @@ -145,6 +152,7 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE)
# Summaries
n_total = nrow(divergent_mat) * ncol(divergent_mat)
total_divergences = sum(divergent_mat)
total_non_reversible = sum(non_reversible_mat)
max_tree_depth_hits = sum(treedepth_mat == nuts_max_depth)
min_ebfmi = min(ebfmi_per_chain)
low_ebfmi_chains = which(ebfmi_per_chain < 0.2)
Expand All @@ -169,6 +177,14 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE)
}
}

if(total_non_reversible > 0) {
non_rev_rate = total_non_reversible / n_total
issues = c(issues, sprintf(
"Non-reversible steps: %d (%.3f%%) - constrained integrator round-trip failed",
total_non_reversible, 100 * non_rev_rate
))
}

if(max_tree_depth_hits > 0) {
if(depth_hit_rate > 0.01) {
issues = c(issues, sprintf(
Expand Down Expand Up @@ -203,11 +219,13 @@ summarize_nuts_diagnostics = function(out, nuts_max_depth = 10, verbose = TRUE)
invisible(list(
treedepth = treedepth_mat,
divergent = divergent_mat,
non_reversible = non_reversible_mat,
energy = energy_mat,
ebfmi = ebfmi_per_chain,
warmup_check = warmup_check,
summary = list(
total_divergences = total_divergences,
total_non_reversible = total_non_reversible,
max_tree_depth_hits = max_tree_depth_hits,
min_ebfmi = min_ebfmi,
warmup_incomplete = any(warmup_check$warmup_incomplete)
Expand Down
18 changes: 9 additions & 9 deletions R/validate_sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,16 @@ validate_sampler = function(update_method,
progress_type = progress_type_from_display_progress(display_progress)

list(
update_method = update_method,
target_accept = target_accept,
iter = iter,
warmup = warmup,
update_method = update_method,
target_accept = target_accept,
iter = iter,
warmup = warmup,
hmc_num_leapfrogs = hmc_num_leapfrogs,
nuts_max_depth = nuts_max_depth,
nuts_max_depth = nuts_max_depth,
learn_mass_matrix = learn_mass_matrix,
chains = chains,
cores = cores,
seed = seed,
progress_type = progress_type
chains = chains,
cores = cores,
seed = seed,
progress_type = progress_type
)
}
21 changes: 21 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,26 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// ggm_test_leapfrog_constrained_checked
Rcpp::List ggm_test_leapfrog_constrained_checked(const arma::vec& x0, const arma::vec& r0, double step_size, int n_steps, const arma::mat& suf_stat, int n, const arma::imat& edge_indicators, double pairwise_scale, double reverse_check_factor, Rcpp::Nullable<Rcpp::NumericVector> inv_mass_in);
RcppExport SEXP _bgms_ggm_test_leapfrog_constrained_checked(SEXP x0SEXP, SEXP r0SEXP, SEXP step_sizeSEXP, SEXP n_stepsSEXP, SEXP suf_statSEXP, SEXP nSEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP reverse_check_factorSEXP, SEXP inv_mass_inSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::vec& >::type x0(x0SEXP);
Rcpp::traits::input_parameter< const arma::vec& >::type r0(r0SEXP);
Rcpp::traits::input_parameter< double >::type step_size(step_sizeSEXP);
Rcpp::traits::input_parameter< int >::type n_steps(n_stepsSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type suf_stat(suf_statSEXP);
Rcpp::traits::input_parameter< int >::type n(nSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type edge_indicators(edge_indicatorsSEXP);
Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP);
Rcpp::traits::input_parameter< double >::type reverse_check_factor(reverse_check_factorSEXP);
Rcpp::traits::input_parameter< Rcpp::Nullable<Rcpp::NumericVector> >::type inv_mass_in(inv_mass_inSEXP);
rcpp_result_gen = Rcpp::wrap(ggm_test_leapfrog_constrained_checked(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_factor, inv_mass_in));
return rcpp_result_gen;
END_RCPP
}
// compute_ess_cpp
Rcpp::NumericVector compute_ess_cpp(Rcpp::NumericVector array3d);
RcppExport SEXP _bgms_compute_ess_cpp(SEXP array3dSEXP) {
Expand Down Expand Up @@ -642,6 +662,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_bgms_ggm_test_logp_and_gradient_full", (DL_FUNC) &_bgms_ggm_test_logp_and_gradient_full, 5},
{"_bgms_ggm_test_project_momentum", (DL_FUNC) &_bgms_ggm_test_project_momentum, 4},
{"_bgms_ggm_test_leapfrog_constrained", (DL_FUNC) &_bgms_ggm_test_leapfrog_constrained, 9},
{"_bgms_ggm_test_leapfrog_constrained_checked", (DL_FUNC) &_bgms_ggm_test_leapfrog_constrained_checked, 10},
{"_bgms_compute_ess_cpp", (DL_FUNC) &_bgms_compute_ess_cpp, 1},
{"_bgms_compute_rhat_cpp", (DL_FUNC) &_bgms_compute_rhat_cpp, 1},
{"_bgms_compute_indicator_ess_cpp", (DL_FUNC) &_bgms_compute_indicator_ess_cpp, 1},
Expand Down
63 changes: 63 additions & 0 deletions src/ggm_gradient_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,66 @@ Rcpp::List ggm_test_leapfrog_constrained(
Rcpp::Named("dH") = H_final - H0
);
}


// [[Rcpp::export]]
Rcpp::List ggm_test_leapfrog_constrained_checked(
const arma::vec& x0,
const arma::vec& r0,
double step_size,
int n_steps,
const arma::mat& suf_stat,
int n,
const arma::imat& edge_indicators,
double pairwise_scale,
double reverse_check_factor = 0.5,
Rcpp::Nullable<Rcpp::NumericVector> inv_mass_in = R_NilValue)
{
size_t p = edge_indicators.n_rows;

arma::mat inc_prob(p, p, arma::fill::value(0.5));
GGMModel model(static_cast<int>(p), suf_stat, inc_prob,
edge_indicators, true, pairwise_scale);

Memoizer::JointFn joint = [&model](const arma::vec& x)
-> std::pair<double, arma::vec> {
return model.logp_and_gradient_full(x);
};
Memoizer memo(joint);

arma::vec inv_mass;
if(inv_mass_in.isNotNull()) {
inv_mass = Rcpp::as<arma::vec>(inv_mass_in);
} else {
inv_mass = arma::ones<arma::vec>(x0.n_elem);
}

ProjectPositionFn proj_pos = [&model, &inv_mass](arma::vec& x) {
model.project_position(x, inv_mass);
};
ProjectMomentumFn proj_mom = [&model, &inv_mass](arma::vec& r, const arma::vec& x) {
model.project_momentum(r, x, inv_mass);
};

arma::vec x = x0;
arma::vec r = r0;
int non_reversible_count = 0;

for (int s = 0; s < n_steps; ++s) {
auto result = leapfrog_constrained_checked(
x, r, step_size, memo, inv_mass, proj_pos, proj_mom,
reverse_check_factor
);
x = std::move(result.theta);
r = std::move(result.r);
if (!result.reversible) {
non_reversible_count++;
}
}

return Rcpp::List::create(
Rcpp::Named("x") = Rcpp::wrap(x),
Rcpp::Named("r") = Rcpp::wrap(r),
Rcpp::Named("non_reversible_count") = non_reversible_count
);
}
32 changes: 27 additions & 5 deletions src/mcmc/algorithms/hmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ StepResult hmc_step(
const arma::vec& inv_mass_diag,
const ProjectPositionFn& project_position,
const ProjectMomentumFn& project_momentum,
SafeRNG& rng
SafeRNG& rng,
bool reverse_check,
double reverse_check_tol
) {
Memoizer memo(joint);

Expand All @@ -198,11 +200,31 @@ StepResult hmc_step(

// Run num_leapfrogs constrained leapfrog steps
arma::vec theta = init_theta;
bool non_reversible = false;
for (int i = 0; i < num_leapfrogs; ++i) {
std::tie(theta, r) = leapfrog_constrained(
theta, r, step_size, memo, inv_mass_diag,
project_position, project_momentum
);
if (reverse_check) {
auto checked = leapfrog_constrained_checked(
theta, r, step_size, memo, inv_mass_diag,
project_position, project_momentum,
reverse_check_tol
);
theta = std::move(checked.theta);
r = std::move(checked.r);
if (!checked.reversible) {
non_reversible = true;
break;
}
} else {
std::tie(theta, r) = leapfrog_constrained(
theta, r, step_size, memo, inv_mass_diag,
project_position, project_momentum
);
}
}

// Non-reversible step forces rejection
if (non_reversible) {
return {init_theta, 0.0};
}

double logp1 = memo.cached_log_post(theta);
Expand Down
6 changes: 5 additions & 1 deletion src/mcmc/algorithms/hmc.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ StepResult hmc_step(
* @param project_position SHAKE position projection callback
* @param project_momentum RATTLE momentum projection callback
* @param rng Thread-safe random number generator
* @param reverse_check Enable runtime reversibility check
* @param reverse_check_tol Factor for eps²-scaled reversibility tolerance
* @return StepResult with accepted state and acceptance probability
*/
StepResult hmc_step(
Expand All @@ -160,5 +162,7 @@ StepResult hmc_step(
const arma::vec& inv_mass_diag,
const ProjectPositionFn& project_position,
const ProjectMomentumFn& project_momentum,
SafeRNG& rng
SafeRNG& rng,
bool reverse_check = true,
double reverse_check_tol = kReverseCheckFactor
);
34 changes: 34 additions & 0 deletions src/mcmc/algorithms/leapfrog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,40 @@ std::pair<arma::vec, arma::vec> leapfrog_constrained(
}


ConstrainedLeapfrogResult leapfrog_constrained_checked(
const arma::vec& theta,
const arma::vec& r,
double eps,
Memoizer& memo,
const arma::vec& inv_mass_diag,
const ProjectPositionFn& project_position,
const ProjectMomentumFn& project_momentum,
double reverse_check_factor
) {
// --- Forward step ---
auto [theta_new, r_new] = leapfrog_constrained(
theta, r, eps, memo, inv_mass_diag,
project_position, project_momentum
);

// --- Backward step (negate momentum, step forward) ---
// Use a separate Memoizer so the caller's cache stays at theta_new.
Memoizer back_memo(memo.joint_fn);
arma::vec r_back = -r_new;
auto [theta_back, r_back_out] = leapfrog_constrained(
theta_new, r_back, eps, back_memo, inv_mass_diag,
project_position, project_momentum
);

// --- Reversibility check (eps^2-scaled max-norm) ---
double max_diff = arma::max(arma::abs(theta_back - theta));
double tol = reverse_check_factor * eps * eps;
bool reversible = (max_diff <= tol);

return {std::move(theta_new), std::move(r_new), reversible, max_diff};
}


LeapfrogJointResult leapfrog(
const arma::vec& theta_init,
const arma::vec& r_init,
Expand Down
Loading
Loading