Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,8 @@ target_link_libraries(simulation_benchmark PRIVATE memilio ode_secir benchmark::
add_executable(graph_simulation_benchmark graph_simulation.cpp)
target_link_libraries(graph_simulation_benchmark PRIVATE memilio ode_secirvvs benchmark::benchmark)

add_executable(secirvvs_advance_benchmark secirvvs_advance.cpp)
target_link_libraries(secirvvs_advance_benchmark PRIVATE memilio ode_secirvvs benchmark::benchmark)

add_executable(abm_benchmark abm.cpp)
target_link_libraries(abm_benchmark PRIVATE abm benchmark::benchmark)
73 changes: 38 additions & 35 deletions cpp/benchmarks/flow_simulation_ode_secirvvs.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class FlowlessModel : public CompartmentalModel<ScalarType, osecirvvs::Infection
{
using InfectionState = osecirvvs::InfectionState;
using Base = CompartmentalModel<ScalarType, osecirvvs::InfectionState,
mio::Populations<ScalarType, AgeGroup, osecirvvs::InfectionState>,
osecirvvs::Parameters<ScalarType>>;
mio::Populations<ScalarType, AgeGroup, osecirvvs::InfectionState>,
osecirvvs::Parameters<ScalarType>>;

public:
FlowlessModel(const Populations& pop, const ParameterSet& params)
Expand Down Expand Up @@ -539,7 +539,6 @@ class Simulation : public Base
*/
Eigen::Ref<Eigen::VectorX<ScalarType>> advance(ScalarType tmax)
{
auto& t_end_dyn_npis = this->get_model().parameters.get_end_dynamic_npis();
auto& dyn_npis =
this->get_model().parameters.template get<osecirvvs::DynamicNPIsInfectedSymptoms<ScalarType>>();
auto& contact_patterns = this->get_model().parameters.template get<osecirvvs::ContactPatterns<ScalarType>>();
Expand All @@ -549,45 +548,36 @@ class Simulation : public Base
this->get_model().parameters.template get<osecirvvs::TransmissionProbabilityOnContact<ScalarType>>();

ScalarType delay_npi_implementation;
auto t = Base::get_result().get_last_time();
const auto dt = dyn_npis.get_interval().get();
auto t = Base::get_result().get_last_time();
while (t < tmax) {

auto dt_eff = std::min({dt, tmax - t, m_t_last_npi_check + dt - t});
if (dt_eff >= 1.0) {
dt_eff = 1.0;
if (t > 0) {
delay_npi_implementation = ScalarType(dyn_npis.get_implementation_delay());
}
else { // DynamicNPIs for t=0 are 'misused' to be from-start NPIs. I.e., do not enforce delay.
delay_npi_implementation = 0;
}

if (t == 0) {
//this->apply_vaccination(t); // done in init now?
this->apply_variant(t, base_infectiousness);
}
Base::advance(t + dt_eff);
if (t + 0.5 + dt_eff - std::floor(t + 0.5) >= 1) {
this->apply_vaccination(t + 0.5 + dt_eff);
this->apply_variant(t, base_infectiousness);
}

if (t > 0) {
delay_npi_implementation = 7;
}
else {
delay_npi_implementation = 0;
}
t = t + dt_eff;

if (dyn_npis.get_thresholds().size() > 0) {
if (floating_point_greater_equal(t, m_t_last_npi_check + dt)) {
if (t < t_end_dyn_npis) {
auto inf_rel = get_infections_relative(*this, t, this->get_result().get_last_value()) *
dyn_npis.get_base_value();
auto exceeded_threshold = dyn_npis.get_max_exceeded_threshold(inf_rel);
if (exceeded_threshold != dyn_npis.get_thresholds().end() &&
(exceeded_threshold->first > m_dynamic_npi.first ||
t > ScalarType(m_dynamic_npi.second))) { //old npi was weaker or is expired

ScalarType direc_begin = ScalarType(dyn_npis.get_directive_begin());
ScalarType direc_end = ScalarType(dyn_npis.get_directive_end());
if (floating_point_greater_equal(t, direc_begin, 1e-10) && t < direc_end) {
auto inf_rel = get_infections_relative(*this, t, this->get_result().get_last_value()) *
dyn_npis.get_base_value();
auto exceeded_threshold = dyn_npis.get_max_exceeded_threshold(inf_rel);
if (exceeded_threshold != dyn_npis.get_thresholds().end() &&
(exceeded_threshold->first > m_dynamic_npi.first ||
t > ScalarType(m_dynamic_npi.second))) { // old npi was weaker or is expired

if (t + delay_npi_implementation < direc_end) {
auto t_start = SimulationTime<ScalarType>(t + delay_npi_implementation);
auto t_end = t_start + SimulationTime<ScalarType>(dyn_npis.get_duration());
// set the end to the minimum of start+delay and the end of the directive
auto t_end = SimulationTime<ScalarType>(
min<ScalarType>(direc_end, ScalarType(t_start + dyn_npis.get_duration())));
this->get_model().parameters.get_start_commuter_detection() = t_start.get();
this->get_model().parameters.get_end_commuter_detection() = t_end.get();
m_dynamic_npi = std::make_pair(exceeded_threshold->first, t_end);
Expand All @@ -597,12 +587,16 @@ class Simulation : public Base
});
}
}
m_t_last_npi_check = t;
}
}
else {
m_t_last_npi_check = t;

auto dt_eff = min<ScalarType>(1.0, tmax - t);
Base::advance(t + dt_eff);
if (t + 0.5 + dt_eff - std::floor(t + 0.5) >= 1) {
this->apply_vaccination(t + 0.5 + dt_eff);
this->apply_variant(t, base_infectiousness);
}
t = t + dt_eff;
}

this->get_model().parameters.template get<osecirvvs::TransmissionProbabilityOnContact<ScalarType>>() =
Expand Down Expand Up @@ -719,6 +713,15 @@ void setup_model(Model& model)
model.parameters.template get<osecirvvs::ReducTimeInfectedMild<ScalarType>>()[AgeGroup(0)] = 0.9;

model.parameters.template get<osecirvvs::Seasonality<ScalarType>>() = 0.2;

auto& npis = model.parameters.template get<osecirvvs::DynamicNPIsInfectedSymptoms<ScalarType>>();
auto npi_groups = Eigen::VectorXd::Ones(contact_matrix[0].get_num_groups());
npis.set_threshold(0.01 * 100'000, {DampingSampling<ScalarType>(0.5, DampingLevel(0), DampingType(0),
SimulationTime<ScalarType>(0), {0}, npi_groups)});
npis.set_base_value(100'000);
npis.set_implementation_delay(SimulationTime<ScalarType>(7.0));
npis.set_duration(SimulationTime<ScalarType>(14.0));

// The function apply_constraints() ensures that all parameters are within their defined bounds.
// Note that negative values are set to zero instead of stopping the simulation.
model.apply_constraints();
Expand Down
93 changes: 93 additions & 0 deletions cpp/benchmarks/secirvvs_advance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (C) 2020-2026 MEmilio
*
* Authors: Henrik Zunker
*
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* Benchmark comparing the overhead of the model-specific osecirvvs::Simulation::advance()
* (which uses substeps and calls apply_vaccination / apply_variant / dynamic NPI checks)
* versus the generic mio::Simulation::advance() (a single integrator call for the entire range).
*/

#include "ode_secirvvs/model.h"
#include "memilio/compartments/simulation.h"
#include "memilio/utils/logging.h"
#include "benchmark/benchmark.h"

using FP = double;
using Model = mio::osecirvvs::Model<FP>;

static Model make_model(bool with_npis = false)
{
constexpr int tmax = 10;
Model model(1);
model.populations[{mio::AgeGroup(0), mio::osecirvvs::InfectionState::InfectedSymptomsNaive}] = 100.0;
model.populations.set_difference_from_total({mio::AgeGroup(0), mio::osecirvvs::InfectionState::SusceptibleNaive},
10000.0);
model.parameters.get<mio::osecirvvs::DailyPartialVaccinations<FP>>().resize(mio::SimulationDay(size_t(tmax + 1)));
model.parameters.get<mio::osecirvvs::DailyFullVaccinations<FP>>().resize(mio::SimulationDay(size_t(tmax + 1)));
if (with_npis) {
auto& npis = model.parameters.get<mio::osecirvvs::DynamicNPIsInfectedSymptoms<FP>>();
npis.set_threshold(0.01 * 100'000, {mio::DampingSampling<FP>{1.0,
mio::DampingLevel(0),
mio::DampingType(0),
mio::SimulationTime<FP>(0),
{0},
Eigen::VectorXd::Ones(1)}});
npis.set_duration(mio::SimulationTime<FP>(14.0));
npis.set_base_value(100'000);
}
return model;
}

// Generic advance: single integrator call for the full [t0, tmax] range.
static void BM_generic(benchmark::State& state)
{
mio::set_log_level(mio::LogLevel::off);
auto model = make_model();
for (auto _ : state) {
mio::simulate<FP, Model>(0., 10., 0.1, model);
}
}

// Model-specific advance without dynamic NPIs: 1-day loop with apply_vaccination + apply_variant,
// dynamic NPI threshold check is skipped (thresholds empty).
static void BM_secirvvs_no_npis(benchmark::State& state)
{
mio::set_log_level(mio::LogLevel::off);
auto model = make_model(/*with_npis=*/false);
for (auto _ : state) {
mio::osecirvvs::simulate<FP>(0., 10., 0.1, model);
}
}

// Model-specific advance with dynamic NPIs: same as above plus get_infections_relative +
// threshold comparison on every day step.
static void BM_secirvvs_with_npis(benchmark::State& state)
{
mio::set_log_level(mio::LogLevel::off);
auto model = make_model(/*with_npis=*/true);
for (auto _ : state) {
mio::osecirvvs::simulate<FP>(0., 10., 0.1, model);
}
}

BENCHMARK(BM_generic)->Name("SECIRVVS generic advance");
BENCHMARK(BM_secirvvs_no_npis)->Name("SECIRVVS model-specific advance (no dynamic NPIs)");
BENCHMARK(BM_secirvvs_with_npis)->Name("SECIRVVS model-specific advance (with dynamic NPIs)");
BENCHMARK_MAIN();
1 change: 0 additions & 1 deletion cpp/examples/ode_secirvvs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ int main()
.get<mio::osecirvvs::DailyFullVaccinations<double>>()[{(mio::AgeGroup)0, mio::SimulationDay(i)}] =
num_vaccinations;
}
model.parameters.get<mio::osecirvvs::DynamicNPIsImplementationDelay<double>>() = 7;

auto& contacts = model.parameters.get<mio::osecirvvs::ContactPatterns<double>>();
auto& contact_matrix = contacts.get_cont_freq_mat();
Expand Down
Loading
Loading