From 52163a8e6cf0779c8b2859f60d5172471b6d15c5 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Mon, 16 Mar 2026 08:26:40 +0100 Subject: [PATCH 1/9] add python bindings for households and abstractparameterdistributions --- .../utils/abstract_parameter_distribution.h | 9 ++++++++- .../memilio/simulation/bindings/models/abm.cpp | 18 +++++++++++++++++- .../memilio/simulation/bindings/simulation.cpp | 7 +++++++ .../bindings/utils/custom_index_array.h | 2 +- .../bindings/utils/parameter_distributions.cpp | 9 +++++++++ .../bindings/utils/parameter_distributions.h | 2 ++ 6 files changed, 44 insertions(+), 3 deletions(-) diff --git a/cpp/memilio/utils/abstract_parameter_distribution.h b/cpp/memilio/utils/abstract_parameter_distribution.h index 6c30b020c7..4abfcd7919 100644 --- a/cpp/memilio/utils/abstract_parameter_distribution.h +++ b/cpp/memilio/utils/abstract_parameter_distribution.h @@ -27,10 +27,17 @@ #include "parameter_distributions.h" #include #include +#include namespace mio { +template +concept HasSampleFunction = requires(T t) { + { t.get_sample(std::declval()) } -> std::convertible_to; + { t.get_sample(std::declval()) } -> std::convertible_to; +}; + /** * @brief This class represents an arbitrary ParameterDistribution. * @see mio::ParameterDistribution @@ -44,7 +51,7 @@ class AbstractParameterDistribution * The implementation handed to the constructor should have get_sample function * overloaded with mio::RandomNumberGenerator and mio::abm::PersonalRandomNumberGenerator as input arguments */ - template + template AbstractParameterDistribution(Impl&& dist) : m_dist(std::make_shared(std::move(dist))) , sample_impl1([](void* d, RandomNumberGenerator& rng) { diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp index 289395bf4c..c8e4b38133 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp @@ -27,6 +27,7 @@ //Includes from MEmilio #include "abm/simulation.h" +#include "abm/household.h" #include "pybind11/attr.h" #include "pybind11/cast.h" @@ -127,11 +128,15 @@ PYBIND11_MODULE(_simulation_abm, m) pymio::bind_CustomIndexArray, mio::abm::VirusVariant, mio::AgeGroup>( m, "_AgeParameterArray"); + pymio::bind_CustomIndexArray( + m, "_DistAgeParameterArray"); + pymio::bind_CustomIndexArray(m, "_boolAgeParameterArray"); + pymio::bind_CustomIndexArray(m, "_intAgeParameterArray"); pymio::bind_CustomIndexArray(m, "_TestData"); pymio::bind_Index(m, "ProtectionTypeIndex"); pymio::bind_ParameterSet(m, "ParametersBase"); pymio::bind_class(m, "Parameters") - .def(py::init()) + .def(py::init()) .def("check_constraints", &mio::abm::Parameters::check_constraints); pymio::bind_ParameterSet( @@ -233,6 +238,17 @@ PYBIND11_MODULE(_simulation_abm, m) py::arg("tmax")) .def_property_readonly("model", py::overload_cast<>(&mio::abm::Simulation<>::get_model)); + pymio::bind_class(m, "HouseholdMember") + .def(py::init(), py::arg("num_agegroups") = 1) + .def_property("age_weights", &mio::abm::HouseholdMember::get_age_weights, + &mio::abm::HouseholdMember::set_age_weight); + + pymio::bind_class(m, "HouseholdGroup").def(py::init<>()); + // .def("add_households", &mio::abm::HouseholdGroup::add_households, + // py::arg("households") = std::vector(), py::arg("num_households") = 1); + + pymio::bind_class(m, "Household").def(py::init<>()); + m.attr("__version__") = "dev"; } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp index 5bb3793ce1..cf94d3a27e 100755 --- a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp @@ -29,6 +29,7 @@ #include "epidemiology/simulation_day.h" #include "math/integrator.h" #include "mobility/metapopulation_mobility_instant.h" +#include "utils/abstract_parameter_distribution.h" #include "utils/date.h" #include "utils/logging.h" #include "utils/time_series.h" @@ -36,6 +37,7 @@ #include "utils/uncertain_value.h" #include "utils/index.h" #include "utils/custom_index_array.h" +#include "utils/random_number_generator.h" //Includes from MEmilio #include "memilio/mobility/metapopulation_mobility_instant.h" @@ -53,9 +55,12 @@ PYBIND11_MODULE(_simulation, m) { pymio::bind_parameter_distribution(m, "ParameterDistribution"); pymio::bind_parameter_distribution_normal(m, "ParameterDistributionNormal"); + pymio::bind_parameter_distribution_lognormal(m, "ParameterDistributionLogNormal"); pymio::bind_parameter_distribution_uniform(m, "ParameterDistributionUniform"); pymio::bind_uncertain_value(m, "UncertainValue"); + pymio::bind_abstract_parameter_distribution(m, "AbstractParameterDistribution"); + pymio::bind_CustomIndexArray, mio::AgeGroup>(m, "AgeGroupArray"); pymio::bind_class>(m, "AgeGroup") .def(py::init()); @@ -155,5 +160,7 @@ PYBIND11_MODULE(_simulation, m) mio::thread_local_rng().seed(mio::RandomNumberGenerator::generate_seeds()); }); + pymio::bind_random_number_generator(m, "RandomNumberGenerator"); + m.attr("__version__") = "dev"; } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/custom_index_array.h b/pycode/memilio-simulation/memilio/simulation/bindings/utils/custom_index_array.h index 2fd49f2c3f..694896333f 100755 --- a/pycode/memilio-simulation/memilio/simulation/bindings/utils/custom_index_array.h +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/custom_index_array.h @@ -231,7 +231,7 @@ void bind_CustomIndexArray(pybind11::module_& m, std::string const& name) //scalar assignment .def("__setitem__", &assign_scalar); - //scalar assignment with conversion from double + // scalar assignment with conversion from double if constexpr (std::is_convertible::value) { c.def("__setitem__", &assign_scalar); } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp index 28498ddd2a..15abfced48 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp @@ -51,6 +51,15 @@ void bind_parameter_distribution_normal(py::module_& m, std::string const& name) &mio::ParameterDistributionNormal::set_standard_dev); } +void bind_parameter_distribution_lognormal(py::module_& m, std::string const& name) +{ + bind_class( + m, name.c_str()) + .def(py::init(), py::arg("log_mean"), py::arg("log_stddev")) + .def_property_readonly("log_mean", &mio::ParameterDistributionLogNormal::get_log_mean) + .def_property_readonly("log_standard_dev", &mio::ParameterDistributionLogNormal::get_log_stddev); +} + void bind_parameter_distribution_uniform(py::module_& m, std::string const& name) { bind_class(m, diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.h b/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.h index d82bce8ec4..e516e525c2 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.h +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.h @@ -29,6 +29,8 @@ void bind_parameter_distribution(pybind11::module_& m, std::string const& name); void bind_parameter_distribution_normal(pybind11::module_& m, std::string const& name); +void bind_parameter_distribution_lognormal(pybind11::module_& m, std::string const& name); + void bind_parameter_distribution_uniform(pybind11::module_& m, std::string const& name); } // namespace pymio From 4f49e04df9aba53a3180a3860948f5340b603b52 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Sat, 21 Mar 2026 18:57:51 +0100 Subject: [PATCH 2/9] exchange contact rates by contact matrix --- cpp/models/abm/model_functions.cpp | 21 +++++++++++++-------- cpp/models/abm/parameters.h | 7 ++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cpp/models/abm/model_functions.cpp b/cpp/models/abm/model_functions.cpp index 224572be4b..d1531883c6 100644 --- a/cpp/models/abm/model_functions.cpp +++ b/cpp/models/abm/model_functions.cpp @@ -42,13 +42,15 @@ ScalarType total_exposure_by_contacts(const ContactExposureRates& rates, const C if (age_receiver == age_transmitter && age_receiver_group_size > 1) // adjust for the person not meeting themself { - total_exposure += rates[{cell_index, virus, age_transmitter}] * - params.get()[{age_receiver, age_transmitter}] * age_receiver_group_size / - (age_receiver_group_size - 1); + total_exposure += + rates[{cell_index, virus, age_transmitter}] * + params.get()[0].get_baseline()((size_t)age_receiver, (size_t)age_transmitter) * + age_receiver_group_size / (age_receiver_group_size - 1); } else { - total_exposure += rates[{cell_index, virus, age_transmitter}] * - params.get()[{age_receiver, age_transmitter}]; + total_exposure += + rates[{cell_index, virus, age_transmitter}] * + params.get()[0].get_baseline()((size_t)age_receiver, (size_t)age_transmitter); } } return total_exposure; @@ -189,12 +191,15 @@ void adjust_contact_rates(Location& location, size_t num_agegroups) ScalarType total_contacts = 0.; // slizing would be preferred but is problematic since both Tags of ContactRates are AgeGroup for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) { - total_contacts += location.get_infection_parameters().get()[{contact_from, contact_to}]; + total_contacts += location.get_infection_parameters().get()[0].get_baseline()( + (size_t)contact_from, (size_t)contact_to); } if (total_contacts > location.get_infection_parameters().get()) { for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) { - location.get_infection_parameters().get()[{contact_from, contact_to}] = - location.get_infection_parameters().get()[{contact_from, contact_to}] * + location.get_infection_parameters().get()[0].get_baseline()((size_t)contact_from, + (size_t)contact_to) = + location.get_infection_parameters().get()[0].get_baseline()((size_t)contact_from, + (size_t)contact_to) * location.get_infection_parameters().get() / total_contacts; } } diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index c7151f17a0..dac769f4cb 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -41,6 +41,7 @@ #include "memilio/epidemiology/age_group.h" #include "memilio/epidemiology/damping.h" #include "memilio/epidemiology/contact_matrix.h" +#include "memilio/epidemiology/uncertain_matrix.h" #include #include @@ -727,11 +728,11 @@ struct MaximumContacts { * contact rates */ struct ContactRates { - using Type = CustomIndexArray; + using Type = ContactMatrixGroup; static Type get_default(AgeGroup size) { - return Type({size, size}, - 1.0); // amount of contacts from AgeGroup a to AgeGroup b per day + return Type( + 1, static_cast((size_t)size)); // amount of contacts from AgeGroup a to AgeGroup b per day } static std::string name() { From e18e5b2ff9bcaee8648ed29d8ed726b2fea90d55 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Sat, 21 Mar 2026 19:09:11 +0100 Subject: [PATCH 3/9] add python bindings for minimal example --- .../simulation/abm_minimal_example.py | 164 ++++++++++++++++++ .../simulation/bindings/models/abm.cpp | 82 +++++++-- .../simulation/bindings/simulation.cpp | 2 + .../utils/abstract_parameter_distribution.h | 40 +++++ .../bindings/utils/random_number_generator.h | 45 +++++ 5 files changed, 323 insertions(+), 10 deletions(-) create mode 100644 pycode/examples/simulation/abm_minimal_example.py create mode 100644 pycode/memilio-simulation/memilio/simulation/bindings/utils/abstract_parameter_distribution.h create mode 100644 pycode/memilio-simulation/memilio/simulation/bindings/utils/random_number_generator.h diff --git a/pycode/examples/simulation/abm_minimal_example.py b/pycode/examples/simulation/abm_minimal_example.py new file mode 100644 index 0000000000..3ab99027dc --- /dev/null +++ b/pycode/examples/simulation/abm_minimal_example.py @@ -0,0 +1,164 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Carlotta Gerstein +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +from memilio.simulation import AgeGroup +from memilio.simulation.abm import Model, VirusVariant +import memilio.simulation as mio + +import numpy as np +from typing import Tuple + +num_age_groups = 4 + +model = Model(num_age_groups) + +# model.rng([1, 2, 3, 4, 5, 6]) + +# Set parameters + +for age_group in range(num_age_groups): + model.parameters.TimeExposedToNoSymptoms[VirusVariant.Wildtype, mio.AgeGroup(age_group)] = mio.AbstractParameterDistribution(mio.ParameterDistributionLogNormal( + 4., 1.)) + + model.parameters.AgeGroupGotoSchool[AgeGroup(age_group)] = False + model.parameters.AgeGroupGotoWork[AgeGroup(age_group)] = False + +model.parameters.AgeGroupGotoSchool[AgeGroup(1)] = True +model.parameters.AgeGroupGotoWork[AgeGroup(2)] = True +model.parameters.AgeGroupGotoWork[AgeGroup(3)] = True + +model.parameters.check_constraints() + +# Set populations + +n_households = 10 + +child = mio.abm.HouseholdMember(num_age_groups) +child.age_weights[AgeGroup(0)] = 1. +child.age_weights[AgeGroup(1)] = 1. + +parent = mio.abm.HouseholdMember(num_age_groups) +parent.age_weights[AgeGroup(2)] = 1. +parent.age_weights[AgeGroup(3)] = 1. + +twoPersonHousehold_group = mio.abm.HouseholdGroup() +twoPersonHousehold_full = mio.abm.Household() +twoPersonHousehold_full.add_members(child, 1) +twoPersonHousehold_full.add_members(parent, 1) +twoPersonHousehold_group.add_households(twoPersonHousehold_full, n_households) +mio.abm.add_household_group_to_model(model, twoPersonHousehold_group) + +threePersonHousehold_group = mio.abm.HouseholdGroup() +threePersonHousehold_full = mio.abm.Household() +threePersonHousehold_full.add_members(child, 1) +threePersonHousehold_full.add_members(parent, 2) +threePersonHousehold_group.add_households( + threePersonHousehold_full, n_households) +mio.abm.add_household_group_to_model(model, threePersonHousehold_group) + +# Set locations + +event = model.add_location(mio.abm.LocationType.SocialEvent) +model.get_location(event).infection_parameters.MaximumContacts = 5 + +hospital = model.add_location(mio.abm.LocationType.Hospital) +model.get_location(hospital).infection_parameters.MaximumContacts = 5 +icu = model.add_location(mio.abm.LocationType.ICU) +model.get_location(icu).infection_parameters.MaximumContacts = 5 + +shop = model.add_location(mio.abm.LocationType.BasicsShop) +model.get_location(shop).infection_parameters.MaximumContacts = 20 + +school = model.add_location(mio.abm.LocationType.School) +model.get_location(school).infection_parameters.MaximumContacts = 20 + +work = model.add_location(mio.abm.LocationType.Work) +model.get_location(work).infection_parameters.MaximumContacts = 20 + +model.parameters.AerosolTransmissionRates[VirusVariant.Wildtype] = 10 + +contacts = np.zeros((num_age_groups, num_age_groups)) +contacts[2, 3] = 10 + +model.get_location( + work).infection_parameters.ContactRates[0].baseline = contacts + +# Testing Schemes + +validity_period = mio.abm.days(1) +probability = 0.5 +start_date = mio.abm.TimePoint(0) +end_date = mio.abm.TimePoint(0) + mio.abm.days(10) +test_type = mio.abm.TestType.Antigen +test_parameters = model.parameters.TestData[test_type] + +testing_criteria_work = mio.abm.TestingCriteria() +testing_scheme_work = mio.abm.TestingScheme( + testing_criteria_work, validity_period, start_date, end_date, test_parameters, probability) + +model.testing_strategy.add_scheme( + mio.abm.LocationType.Work, testing_scheme_work) + +# Seed infections + +infection_distribution = [0.5, 0.3, 0.05, 0.05, 0.05, 0.05, 0.0, 0.0] +for person in model.persons: + infection_state = mio.abm.InfectionState(np.random.choice( + len(infection_distribution), p=infection_distribution)) + rng = mio.abm.PersonalRandomNumberGenerator(person) + + if infection_state != mio.abm.InfectionState.Susceptible: + person.add_new_infection(rng, mio.abm.VirusVariant.Wildtype, + person.age, model.parameters, start_date, infection_state) + +for person in model.persons: + id = person.id + + model.assign_location(id, event) + model.assign_location(id, shop) + + model.assign_location(id, hospital) + model.assign_location(id, icu) + + if person.age == AgeGroup(1): + model.assign_location(id, school) + + if person.age == AgeGroup(2) or person.age == AgeGroup(3): + model.assign_location(id, work) + +t_lockdown = mio.abm.TimePoint(0) + mio.abm.days(10) +mio.abm.close_social_events(t_lockdown, 0.9, model.parameters) + +t0 = mio.abm.TimePoint(0) +tmax = t0 + mio.abm.days(10) +sim = mio.abm.Simulation(t0, model) + + +history = mio.abm.History(mio.TimeSeries(num_age_groups)) + +sim.advance(tmax, history) + +for person in sim.model.persons: + # print("start_date: ", person.get_infection_state(start_date)) + if (person.get_infection_state(start_date) == mio.abm.InfectionState.Susceptible): + + print("end_date: ", person.get_infection_state(end_date)) + +history.get_log().print_table() diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp index c8e4b38133..cb401d7389 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp @@ -28,6 +28,8 @@ //Includes from MEmilio #include "abm/simulation.h" #include "abm/household.h" +#include "abm/lockdown_rules.h" +#include "abm/common_abm_loggers.h" #include "pybind11/attr.h" #include "pybind11/cast.h" @@ -132,6 +134,7 @@ PYBIND11_MODULE(_simulation_abm, m) m, "_DistAgeParameterArray"); pymio::bind_CustomIndexArray(m, "_boolAgeParameterArray"); pymio::bind_CustomIndexArray(m, "_intAgeParameterArray"); + pymio::bind_CustomIndexArray(m, "_doubleVirusVariantArray"); pymio::bind_CustomIndexArray(m, "_TestData"); pymio::bind_Index(m, "ProtectionTypeIndex"); pymio::bind_ParameterSet(m, "ParametersBase"); @@ -156,11 +159,21 @@ PYBIND11_MODULE(_simulation_abm, m) &mio::abm::Person::set_assigned_location)) .def_property_readonly("location", py::overload_cast<>(&mio::abm::Person::get_location, py::const_)) .def_property_readonly("age", &mio::abm::Person::get_age) - .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine); + .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine) + .def_property_readonly("id", &mio::abm::Person::get_id) + .def("add_new_infection", + [](mio::abm::Person& self, mio::abm::PersonalRandomNumberGenerator& rng, mio::abm::VirusVariant variant, + mio::AgeGroup age, mio::abm::Parameters& parameters, mio::abm::TimePoint start_date, + mio::abm::InfectionState infection_state) { + self.add_new_infection( + mio::abm::Infection(rng, variant, age, parameters, start_date, infection_state)); + }) + .def("get_infection_state", &mio::abm::Person::get_infection_state); pymio::bind_class(m, "TestingCriteria") .def(py::init&, const std::vector&>(), - py::arg("age_groups"), py::arg("infection_states")); + py::arg("age_groups"), py::arg("infection_states")) + .def(py::init<>()); pymio::bind_class(m, "TestingScheme") .def(py::init(m, "TestingStrategy") .def(py::init&, - const std::vector&>()); + const std::vector&>()) + .def("add_scheme", + py::overload_cast( + &mio::abm::TestingStrategy::add_scheme), + py::arg("location_type"), py::arg("testing_scheme")) + .def("add_scheme", + py::overload_cast( + &mio::abm::TestingStrategy::add_scheme), + py::arg("location_id"), py::arg("testing_scheme")); pymio::bind_class(m, "Location") + .def(py::init(), py::arg("location_type"), py::arg("loc_id")) .def_property_readonly("type", &mio::abm::Location::get_type) .def_property_readonly("id", &mio::abm::Location::get_id) .def_property("infection_parameters", @@ -215,6 +237,8 @@ PYBIND11_MODULE(_simulation_abm, m) py::keep_alive<1, 0>{}) //keep this model alive while contents are referenced in ranges .def_property_readonly("persons", py::overload_cast<>(&mio::abm::Model::get_persons, py::const_), py::keep_alive<1, 0>{}) + .def("get_location", py::overload_cast(&mio::abm::Model::get_location), + py::arg("location_id"), py::return_value_policy::reference_internal) .def_property( "trip_list", py::overload_cast<>(&mio::abm::Model::get_trip_list), [](mio::abm::Model& self, const mio::abm::TripList& list) { @@ -233,9 +257,23 @@ PYBIND11_MODULE(_simulation_abm, m) pymio::bind_class, pymio::EnablePickling::Never>(m, "Simulation") .def(py::init()) - .def("advance", - static_cast::*)(mio::abm::TimePoint)>(&mio::abm::Simulation<>::advance), - py::arg("tmax")) + .def(py::init([](mio::abm::TimePoint t, mio::abm::Model& model) { + return mio::abm::Simulation(t, std::move(model)); + }), + py::return_value_policy::reference_internal) + .def( + "advance", + [](mio::abm::Simulation<>& sim, mio::abm::TimePoint tmax) { + sim.advance(tmax); + }, + py::arg("tmax")) + .def( + "advance", + [](mio::abm::Simulation<>& sim, mio::abm::TimePoint tmax, + mio::History& history) { + sim.advance(tmax, history); + }, + py::arg("tmax"), py::arg("history")) .def_property_readonly("model", py::overload_cast<>(&mio::abm::Simulation<>::get_model)); pymio::bind_class(m, "HouseholdMember") @@ -243,11 +281,35 @@ PYBIND11_MODULE(_simulation_abm, m) .def_property("age_weights", &mio::abm::HouseholdMember::get_age_weights, &mio::abm::HouseholdMember::set_age_weight); - pymio::bind_class(m, "HouseholdGroup").def(py::init<>()); - // .def("add_households", &mio::abm::HouseholdGroup::add_households, - // py::arg("households") = std::vector(), py::arg("num_households") = 1); + pymio::bind_class(m, "HouseholdGroup") + .def(py::init<>()) + .def("add_households", &mio::abm::HouseholdGroup::add_households) + .def_property_readonly("num_households", &mio::abm::HouseholdGroup::get_total_number_of_households); + + pymio::bind_class(m, "Household") + .def(py::init<>()) + .def("add_members", &mio::abm::Household::add_members) + .def_property_readonly("num_members", &mio::abm::Household::get_total_number_of_members); + + m.def("add_household_group_to_model", &mio::abm::add_household_group_to_model); - pymio::bind_class(m, "Household").def(py::init<>()); + pymio::bind_class( + m, "PersonalRandomNumberGenerator") + .def(py::init(), py::arg("person")); + + pymio::bind_class(m, "Infection"); + + m.def("close_social_events", &mio::abm::close_social_events); + + pymio::bind_class, + pymio::EnablePickling::Never>(m, "History") + .def(py::init>()) + .def( + "get_log", + [](mio::History& self) { + return std::get<0>(self.get_log()); + }, + py::return_value_policy::reference_internal); m.attr("__version__") = "dev"; } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp index cf94d3a27e..7d327f9435 100755 --- a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp @@ -162,5 +162,7 @@ PYBIND11_MODULE(_simulation, m) pymio::bind_random_number_generator(m, "RandomNumberGenerator"); + pymio::bind_discrete_distribution(m, "DiscreteDistribution"); + m.attr("__version__") = "dev"; } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/abstract_parameter_distribution.h b/pycode/memilio-simulation/memilio/simulation/bindings/utils/abstract_parameter_distribution.h new file mode 100644 index 0000000000..641e97e9c8 --- /dev/null +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/abstract_parameter_distribution.h @@ -0,0 +1,40 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: Carlotta Gerstein +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "memilio/utils/abstract_parameter_distribution.h" +#include "memilio/utils/random_number_generator.h" +#include "pybind_util.h" + +namespace py = pybind11; + +namespace pymio +{ +void bind_abstract_parameter_distribution(py::module_& m, std::string const& name) +{ + bind_class(m, name.c_str()) + .def(py::init<>()) + .def(py::init(), py::arg("dist")) + .def("get", + [](mio::AbstractParameterDistribution& self, mio::RandomNumberGenerator& rng) { + return self.get(rng); + }) + .def("params", &mio::AbstractParameterDistribution::params); +} +} // namespace pymio \ No newline at end of file diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/random_number_generator.h b/pycode/memilio-simulation/memilio/simulation/bindings/utils/random_number_generator.h new file mode 100644 index 0000000000..568980e73d --- /dev/null +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/random_number_generator.h @@ -0,0 +1,45 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: Carlotta Gerstein +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "memilio/utils/random_number_generator.h" +#include "pybind_util.h" + +namespace py = pybind11; + +namespace pymio +{ +void bind_random_number_generator(py::module_& m, std::string const& name) +{ + bind_class(m, name.c_str()) + .def(py::init<>()) + .def_property_readonly("key", &mio::RandomNumberGenerator::get_key) + .def_property("counter", &mio::RandomNumberGenerator::get_counter, &mio::RandomNumberGenerator::set_counter) + .def_property_readonly("seeds", &mio::RandomNumberGenerator::get_seeds) + .def("increment_counter", &mio::RandomNumberGenerator::increment_counter) + .def("seed", &mio::RandomNumberGenerator::seed, py::arg("seeds")); +} + +void bind_discrete_distribution(py::module_& m, std::string const& name) +{ + bind_class, EnablePickling::Never>(m, name.c_str()) + .def("get_instance", &mio::DiscreteDistribution::get_instance, + py::return_value_policy::reference_internal); +} +} // namespace pymio \ No newline at end of file From cc3ddb2fd7d58d3983b068e95d617ad78db9c968 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:35:46 +0100 Subject: [PATCH 4/9] change to ContactMatrix --- cpp/examples/abm_minimal.cpp | 10 ++++----- cpp/models/abm/model_functions.cpp | 22 +++++++++---------- cpp/models/abm/parameters.h | 4 ++-- .../simulation/abm_minimal_example.py | 11 +++++----- .../simulation/bindings/models/abm.cpp | 5 ++++- 5 files changed, 25 insertions(+), 27 deletions(-) diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index fe0501edfb..d14e1fb5e2 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -99,9 +99,8 @@ int main() // Increase aerosol transmission for all locations model.parameters.get() = 10.0; // Increase contact rate for all people between 15 and 34 (i.e. people meet more often in the same location) - model.get_location(work) - .get_infection_parameters() - .get()[{age_group_15_to_34, age_group_15_to_34}] = 10.0; + model.get_location(work).get_infection_parameters().get().get_baseline()( + (size_t)age_group_15_to_34, (size_t)age_group_15_to_34) = 10.0; // People can get tested at work (and do this with 0.5 probability) from time point 0 to day 10. auto validity_period = mio::abm::days(1); @@ -166,9 +165,8 @@ int main() // The first column is Time. The other columns correspond to the number of people with a certain infection state at this Time: // Time = Time in days, S = Susceptible, E = Exposed, I_NS = InfectedNoSymptoms, I_Sy = InfectedSymptoms, I_Sev = InfectedSevere, // I_Crit = InfectedCritical, R = Recovered, D = Dead - std::ofstream outfile("abm_minimal.txt"); - std::get<0>(historyTimeSeries.get_log()) - .print_table(outfile, {"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4); + // std::ofstream outfile("abm_minimal.txt"); + std::get<0>(historyTimeSeries.get_log()).print_table({"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4); std::cout << "Results written to abm_minimal.txt" << std::endl; return 0; diff --git a/cpp/models/abm/model_functions.cpp b/cpp/models/abm/model_functions.cpp index d1531883c6..9ebb86f483 100644 --- a/cpp/models/abm/model_functions.cpp +++ b/cpp/models/abm/model_functions.cpp @@ -42,15 +42,13 @@ ScalarType total_exposure_by_contacts(const ContactExposureRates& rates, const C if (age_receiver == age_transmitter && age_receiver_group_size > 1) // adjust for the person not meeting themself { - total_exposure += - rates[{cell_index, virus, age_transmitter}] * - params.get()[0].get_baseline()((size_t)age_receiver, (size_t)age_transmitter) * - age_receiver_group_size / (age_receiver_group_size - 1); + total_exposure += rates[{cell_index, virus, age_transmitter}] * + params.get().get_baseline()((size_t)age_receiver, (size_t)age_transmitter) * + age_receiver_group_size / (age_receiver_group_size - 1); } else { - total_exposure += - rates[{cell_index, virus, age_transmitter}] * - params.get()[0].get_baseline()((size_t)age_receiver, (size_t)age_transmitter); + total_exposure += rates[{cell_index, virus, age_transmitter}] * + params.get().get_baseline()((size_t)age_receiver, (size_t)age_transmitter); } } return total_exposure; @@ -191,15 +189,15 @@ void adjust_contact_rates(Location& location, size_t num_agegroups) ScalarType total_contacts = 0.; // slizing would be preferred but is problematic since both Tags of ContactRates are AgeGroup for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) { - total_contacts += location.get_infection_parameters().get()[0].get_baseline()( + total_contacts += location.get_infection_parameters().get().get_baseline()( (size_t)contact_from, (size_t)contact_to); } if (total_contacts > location.get_infection_parameters().get()) { for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) { - location.get_infection_parameters().get()[0].get_baseline()((size_t)contact_from, - (size_t)contact_to) = - location.get_infection_parameters().get()[0].get_baseline()((size_t)contact_from, - (size_t)contact_to) * + location.get_infection_parameters().get().get_baseline()((size_t)contact_from, + (size_t)contact_to) = + location.get_infection_parameters().get().get_baseline()((size_t)contact_from, + (size_t)contact_to) * location.get_infection_parameters().get() / total_contacts; } } diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index dac769f4cb..daf256736a 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -728,11 +728,11 @@ struct MaximumContacts { * contact rates */ struct ContactRates { - using Type = ContactMatrixGroup; + using Type = ContactMatrix; static Type get_default(AgeGroup size) { return Type( - 1, static_cast((size_t)size)); // amount of contacts from AgeGroup a to AgeGroup b per day + static_cast((size_t)size)); // amount of contacts from AgeGroup a to AgeGroup b per day } static std::string name() { diff --git a/pycode/examples/simulation/abm_minimal_example.py b/pycode/examples/simulation/abm_minimal_example.py index 3ab99027dc..8c5fc7b2c1 100644 --- a/pycode/examples/simulation/abm_minimal_example.py +++ b/pycode/examples/simulation/abm_minimal_example.py @@ -29,8 +29,6 @@ model = Model(num_age_groups) -# model.rng([1, 2, 3, 4, 5, 6]) - # Set parameters for age_group in range(num_age_groups): @@ -98,7 +96,7 @@ contacts[2, 3] = 10 model.get_location( - work).infection_parameters.ContactRates[0].baseline = contacts + work).infection_parameters.ContactRates.baseline = contacts # Testing Schemes @@ -119,13 +117,14 @@ # Seed infections infection_distribution = [0.5, 0.3, 0.05, 0.05, 0.05, 0.05, 0.0, 0.0] +rng = np.random.default_rng() for person in model.persons: - infection_state = mio.abm.InfectionState(np.random.choice( + infection_state = mio.abm.InfectionState(rng.choice( len(infection_distribution), p=infection_distribution)) - rng = mio.abm.PersonalRandomNumberGenerator(person) + prng = mio.abm.PersonalRandomNumberGenerator(person) if infection_state != mio.abm.InfectionState.Susceptible: - person.add_new_infection(rng, mio.abm.VirusVariant.Wildtype, + person.add_new_infection(prng, mio.abm.VirusVariant.Wildtype, person.age, model.parameters, start_date, infection_state) for person in model.persons: diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp index cb401d7389..85782468fc 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp @@ -253,7 +253,10 @@ PYBIND11_MODULE(_simulation_abm, m) [](mio::abm::Model& self, mio::abm::TestingStrategy strategy) { self.get_testing_strategy() = strategy; }, - py::return_value_policy::reference_internal); + py::return_value_policy::reference_internal) + .def("seed_rng", [](mio::abm::Model& self, std::vector seeds) { + self.get_rng().seed(seeds); + }); pymio::bind_class, pymio::EnablePickling::Never>(m, "Simulation") .def(py::init()) From 5851eecf9db9e266fa568b801599650ea2174a88 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Wed, 25 Mar 2026 13:12:16 +0100 Subject: [PATCH 5/9] fix bug --- pycode/examples/simulation/abm_minimal_example.py | 8 +------- .../memilio/simulation/bindings/pybind_util.h | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pycode/examples/simulation/abm_minimal_example.py b/pycode/examples/simulation/abm_minimal_example.py index 8c5fc7b2c1..ba86ea18d7 100644 --- a/pycode/examples/simulation/abm_minimal_example.py +++ b/pycode/examples/simulation/abm_minimal_example.py @@ -150,14 +150,8 @@ sim = mio.abm.Simulation(t0, model) -history = mio.abm.History(mio.TimeSeries(num_age_groups)) +history = mio.abm.History(mio.TimeSeries(len(mio.abm.InfectionState.values()))) sim.advance(tmax, history) -for person in sim.model.persons: - # print("start_date: ", person.get_infection_state(start_date)) - if (person.get_infection_state(start_date) == mio.abm.InfectionState.Susceptible): - - print("end_date: ", person.get_infection_state(end_date)) - history.get_log().print_table() diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/pybind_util.h b/pycode/memilio-simulation/memilio/simulation/bindings/pybind_util.h index dc48195817..822ded1839 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/pybind_util.h +++ b/pycode/memilio-simulation/memilio/simulation/bindings/pybind_util.h @@ -255,7 +255,7 @@ auto bind_Range(pybind11::module_& m, const std::string& class_name) .def( "__iter__", [](Range& self) { - return self; + return Iterator{{self.begin(), self.end()}}; }, pybind11::keep_alive<1, 0>{}) //keep alive the Range as long as there is an iterator .def( From 7e352b1f1b9967ad04034373c0a4328002a8a08b Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Thu, 26 Mar 2026 10:59:45 +0100 Subject: [PATCH 6/9] rename History --- pycode/examples/simulation/abm_minimal_example.py | 3 ++- .../memilio/simulation/bindings/models/abm.cpp | 2 +- .../memilio/simulation/bindings/simulation.cpp | 2 -- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pycode/examples/simulation/abm_minimal_example.py b/pycode/examples/simulation/abm_minimal_example.py index ba86ea18d7..f1dd3b6c24 100644 --- a/pycode/examples/simulation/abm_minimal_example.py +++ b/pycode/examples/simulation/abm_minimal_example.py @@ -150,7 +150,8 @@ sim = mio.abm.Simulation(t0, model) -history = mio.abm.History(mio.TimeSeries(len(mio.abm.InfectionState.values()))) +history = mio.abm.TimeSeriesWriterLogInfectionStateHistory( + mio.TimeSeries(len(mio.abm.InfectionState.values()))) sim.advance(tmax, history) diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp index 85782468fc..c26ca2f216 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp @@ -305,7 +305,7 @@ PYBIND11_MODULE(_simulation_abm, m) m.def("close_social_events", &mio::abm::close_social_events); pymio::bind_class, - pymio::EnablePickling::Never>(m, "History") + pymio::EnablePickling::Never>(m, "TimeSeriesWriterLogInfectionStateHistory") .def(py::init>()) .def( "get_log", diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp index 7d327f9435..cf94d3a27e 100755 --- a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp @@ -162,7 +162,5 @@ PYBIND11_MODULE(_simulation, m) pymio::bind_random_number_generator(m, "RandomNumberGenerator"); - pymio::bind_discrete_distribution(m, "DiscreteDistribution"); - m.attr("__version__") = "dev"; } From 9e7c6aa5e427c746c86ca3cf43d18f80e46c2b54 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:40:45 +0200 Subject: [PATCH 7/9] add bindings for vaccinations --- .../simulation/abm_minimal_example.py | 112 ++++++++++++------ .../bindings/math/time_series_functor.h | 48 ++++++++ .../simulation/bindings/models/abm.cpp | 14 ++- .../simulation/bindings/simulation.cpp | 3 + .../memilio/simulation/bindings/utils/index.h | 1 + .../simulation/bindings/utils/time_series.cpp | 2 + 6 files changed, 140 insertions(+), 40 deletions(-) create mode 100644 pycode/memilio-simulation/memilio/simulation/bindings/math/time_series_functor.h diff --git a/pycode/examples/simulation/abm_minimal_example.py b/pycode/examples/simulation/abm_minimal_example.py index f1dd3b6c24..95c8a5a9db 100644 --- a/pycode/examples/simulation/abm_minimal_example.py +++ b/pycode/examples/simulation/abm_minimal_example.py @@ -19,20 +19,20 @@ ############################################################################# from memilio.simulation import AgeGroup -from memilio.simulation.abm import Model, VirusVariant +import memilio.simulation.abm as abm import memilio.simulation as mio import numpy as np -from typing import Tuple +import random num_age_groups = 4 -model = Model(num_age_groups) +model = abm.Model(num_age_groups) # Set parameters for age_group in range(num_age_groups): - model.parameters.TimeExposedToNoSymptoms[VirusVariant.Wildtype, mio.AgeGroup(age_group)] = mio.AbstractParameterDistribution(mio.ParameterDistributionLogNormal( + model.parameters.TimeExposedToNoSymptoms[abm.VirusVariant.Wildtype, AgeGroup(age_group)] = mio.AbstractParameterDistribution(mio.ParameterDistributionLogNormal( 4., 1.)) model.parameters.AgeGroupGotoSchool[AgeGroup(age_group)] = False @@ -42,55 +42,64 @@ model.parameters.AgeGroupGotoWork[AgeGroup(2)] = True model.parameters.AgeGroupGotoWork[AgeGroup(3)] = True +for age in range(num_age_groups): + model.parameters.InfectionProtectionFactor[abm.ProtectionType.GenericVaccine, AgeGroup( + age), abm.VirusVariant.Wildtype] = mio.TimeSeriesFunctor( + [[0, 0.0], [14, 0.67], [180, 0.4]]) + + model.parameters.SeverityProtectionFactor[abm.ProtectionType.GenericVaccine, AgeGroup( + age), abm.VirusVariant.Wildtype] = mio.TimeSeriesFunctor( + [[0, 0.0], [14, 0.85], [180, 0.7]]) + model.parameters.check_constraints() # Set populations n_households = 10 -child = mio.abm.HouseholdMember(num_age_groups) +child = abm.HouseholdMember(num_age_groups) child.age_weights[AgeGroup(0)] = 1. child.age_weights[AgeGroup(1)] = 1. -parent = mio.abm.HouseholdMember(num_age_groups) +parent = abm.HouseholdMember(num_age_groups) parent.age_weights[AgeGroup(2)] = 1. parent.age_weights[AgeGroup(3)] = 1. -twoPersonHousehold_group = mio.abm.HouseholdGroup() -twoPersonHousehold_full = mio.abm.Household() +twoPersonHousehold_group = abm.HouseholdGroup() +twoPersonHousehold_full = abm.Household() twoPersonHousehold_full.add_members(child, 1) twoPersonHousehold_full.add_members(parent, 1) twoPersonHousehold_group.add_households(twoPersonHousehold_full, n_households) -mio.abm.add_household_group_to_model(model, twoPersonHousehold_group) +abm.add_household_group_to_model(model, twoPersonHousehold_group) -threePersonHousehold_group = mio.abm.HouseholdGroup() -threePersonHousehold_full = mio.abm.Household() +threePersonHousehold_group = abm.HouseholdGroup() +threePersonHousehold_full = abm.Household() threePersonHousehold_full.add_members(child, 1) threePersonHousehold_full.add_members(parent, 2) threePersonHousehold_group.add_households( threePersonHousehold_full, n_households) -mio.abm.add_household_group_to_model(model, threePersonHousehold_group) +abm.add_household_group_to_model(model, threePersonHousehold_group) # Set locations -event = model.add_location(mio.abm.LocationType.SocialEvent) +event = model.add_location(abm.LocationType.SocialEvent) model.get_location(event).infection_parameters.MaximumContacts = 5 -hospital = model.add_location(mio.abm.LocationType.Hospital) +hospital = model.add_location(abm.LocationType.Hospital) model.get_location(hospital).infection_parameters.MaximumContacts = 5 -icu = model.add_location(mio.abm.LocationType.ICU) +icu = model.add_location(abm.LocationType.ICU) model.get_location(icu).infection_parameters.MaximumContacts = 5 -shop = model.add_location(mio.abm.LocationType.BasicsShop) +shop = model.add_location(abm.LocationType.BasicsShop) model.get_location(shop).infection_parameters.MaximumContacts = 20 -school = model.add_location(mio.abm.LocationType.School) +school = model.add_location(abm.LocationType.School) model.get_location(school).infection_parameters.MaximumContacts = 20 -work = model.add_location(mio.abm.LocationType.Work) +work = model.add_location(abm.LocationType.Work) model.get_location(work).infection_parameters.MaximumContacts = 20 -model.parameters.AerosolTransmissionRates[VirusVariant.Wildtype] = 10 +model.parameters.AerosolTransmissionRates[abm.VirusVariant.Wildtype] = 10 contacts = np.zeros((num_age_groups, num_age_groups)) contacts[2, 3] = 10 @@ -100,33 +109,35 @@ # Testing Schemes -validity_period = mio.abm.days(1) +validity_period = abm.days(1) probability = 0.5 -start_date = mio.abm.TimePoint(0) -end_date = mio.abm.TimePoint(0) + mio.abm.days(10) -test_type = mio.abm.TestType.Antigen +start_date = abm.TimePoint(0) +end_date = abm.TimePoint(0) + abm.days(10) +test_type = abm.TestType.Antigen test_parameters = model.parameters.TestData[test_type] -testing_criteria_work = mio.abm.TestingCriteria() -testing_scheme_work = mio.abm.TestingScheme( +testing_criteria_work = abm.TestingCriteria() +testing_scheme_work = abm.TestingScheme( testing_criteria_work, validity_period, start_date, end_date, test_parameters, probability) model.testing_strategy.add_scheme( - mio.abm.LocationType.Work, testing_scheme_work) + abm.LocationType.Work, testing_scheme_work) # Seed infections infection_distribution = [0.5, 0.3, 0.05, 0.05, 0.05, 0.05, 0.0, 0.0] rng = np.random.default_rng() for person in model.persons: - infection_state = mio.abm.InfectionState(rng.choice( + infection_state = abm.InfectionState(rng.choice( len(infection_distribution), p=infection_distribution)) - prng = mio.abm.PersonalRandomNumberGenerator(person) + prng = abm.PersonalRandomNumberGenerator(person) - if infection_state != mio.abm.InfectionState.Susceptible: - person.add_new_infection(prng, mio.abm.VirusVariant.Wildtype, + if infection_state != abm.InfectionState.Susceptible: + person.add_new_infection(prng, abm.VirusVariant.Wildtype, person.age, model.parameters, start_date, infection_state) +# Assign locations + for person in model.persons: id = person.id @@ -142,16 +153,43 @@ if person.age == AgeGroup(2) or person.age == AgeGroup(3): model.assign_location(id, work) -t_lockdown = mio.abm.TimePoint(0) + mio.abm.days(10) -mio.abm.close_social_events(t_lockdown, 0.9, model.parameters) +# Vaccinations + +vacc_rate = 0.7 +vaccination_priority = [AgeGroup(3), AgeGroup(2), AgeGroup(1)] +vaccination_time = start_date - abm.days(20) + +persons_by_age = [[] for _ in range(num_age_groups)] +for idx, person in enumerate(model.persons): + persons_by_age[person.age.get()].append(idx) + +for age in vaccination_priority: + indices = persons_by_age[age.get()] + + random.shuffle(indices) + + temp = vacc_rate * len(indices) + n_to_vaccinate = int(np.round(vacc_rate * len(indices))) + + count = 0 + for i in range(n_to_vaccinate): + person = model.persons[indices[i]] + if person.get_infection_state(vaccination_time) == abm.InfectionState.Susceptible: + person.add_new_vaccination( + abm.ProtectionType.GenericVaccine, vaccination_time) + +# Simulate + +t_lockdown = start_date + abm.days(10) +abm.close_social_events(t_lockdown, 0.9, model.parameters) -t0 = mio.abm.TimePoint(0) -tmax = t0 + mio.abm.days(10) -sim = mio.abm.Simulation(t0, model) +t0 = start_date +tmax = t0 + abm.days(10) +sim = abm.Simulation(t0, model) -history = mio.abm.TimeSeriesWriterLogInfectionStateHistory( - mio.TimeSeries(len(mio.abm.InfectionState.values()))) +history = abm.TimeSeriesWriterLogInfectionStateHistory( + mio.TimeSeries(len(abm.InfectionState.values()))) sim.advance(tmax, history) diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/math/time_series_functor.h b/pycode/memilio-simulation/memilio/simulation/bindings/math/time_series_functor.h new file mode 100644 index 0000000000..2842b2edff --- /dev/null +++ b/pycode/memilio-simulation/memilio/simulation/bindings/math/time_series_functor.h @@ -0,0 +1,48 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: Carlotta Gerstein +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "memilio/math/time_series_functor.h" +#include "memilio/math/interpolation.h" + +#include "pybind_util.h" +#include "pybind11/pybind11.h" + +namespace py = pybind11; + +namespace pymio +{ + +void bind_time_series_functor(py::module_& m, std::string const& name) +{ + bind_class, EnablePickling::Never>(m, name.c_str()) + .def(py::init()) + .def(py::init>()) + .def(py::init([](const mio::TimeSeries& data) { + return mio::TimeSeriesFunctor(mio::TimeSeriesFunctorType::LinearInterpolation, data); + })) + .def(py::init([](std::vector>&& table) { + return mio::TimeSeriesFunctor(mio::TimeSeriesFunctorType::LinearInterpolation, table); + })) + .def("__call__", [](mio::TimeSeriesFunctor& self, double time) { + return self(time); + }); +} + +} // namespace pymio diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp index c26ca2f216..952f404343 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp @@ -136,11 +136,13 @@ PYBIND11_MODULE(_simulation_abm, m) pymio::bind_CustomIndexArray(m, "_intAgeParameterArray"); pymio::bind_CustomIndexArray(m, "_doubleVirusVariantArray"); pymio::bind_CustomIndexArray(m, "_TestData"); - pymio::bind_Index(m, "ProtectionTypeIndex"); + pymio::bind_CustomIndexArray, mio::abm::ProtectionType, mio::AgeGroup, + mio::abm::VirusVariant>(m, "_ProtectionFactorArray"); pymio::bind_ParameterSet(m, "ParametersBase"); pymio::bind_class(m, "Parameters") .def(py::init()) - .def("check_constraints", &mio::abm::Parameters::check_constraints); + .def("check_constraints", &mio::abm::Parameters::check_constraints) + .def_property_readonly("num_age_groups", &mio::abm::Parameters::get_num_groups); pymio::bind_ParameterSet( m, "LocalInfectionParameters") @@ -168,7 +170,13 @@ PYBIND11_MODULE(_simulation_abm, m) self.add_new_infection( mio::abm::Infection(rng, variant, age, parameters, start_date, infection_state)); }) - .def("get_infection_state", &mio::abm::Person::get_infection_state); + .def("add_new_vaccination", &mio::abm::Person::add_new_vaccination, py::return_value_policy::reference_internal) + // .def("add_new_vaccination", + // [](mio::abm::Person& self, mio::abm::ProtectionType type, mio::abm::TimePoint start_date) { + // self.add_new_vaccination(type, start_date); + // }) + .def("get_infection_state", &mio::abm::Person::get_infection_state) + .def_property_readonly("vaccinations", py::overload_cast<>(&mio::abm::Person::get_vaccinations, py::const_)); pymio::bind_class(m, "TestingCriteria") .def(py::init&, const std::vector&>(), diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp index cf94d3a27e..1da0f33985 100755 --- a/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/simulation.cpp @@ -28,6 +28,7 @@ #include "epidemiology/dynamic_npis.h" #include "epidemiology/simulation_day.h" #include "math/integrator.h" +#include "math/time_series_functor.h" #include "mobility/metapopulation_mobility_instant.h" #include "utils/abstract_parameter_distribution.h" #include "utils/date.h" @@ -162,5 +163,7 @@ PYBIND11_MODULE(_simulation, m) pymio::bind_random_number_generator(m, "RandomNumberGenerator"); + pymio::bind_time_series_functor(m, "TimeSeriesFunctor"); + m.attr("__version__") = "dev"; } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/index.h b/pycode/memilio-simulation/memilio/simulation/bindings/utils/index.h index cbaf6f8315..3d497cdfc2 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/utils/index.h +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/index.h @@ -51,6 +51,7 @@ void bind_Index(pybind11::module_& m, std::string const& name) c.def(pybind11::init(), pybind11::arg("value")); c.def(pybind11::self == pybind11::self); c.def(pybind11::self != pybind11::self); + c.def("get", &mio::Index::get); bind_Index_members_if_enum(c); } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/time_series.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/utils/time_series.cpp index 862d22a72f..9ff15c7059 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/utils/time_series.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/time_series.cpp @@ -19,6 +19,7 @@ */ #include "utils/time_series.h" #include "memilio/utils/time_series.h" +#include "memilio/math/time_series_functor.h" #include "pybind_util.h" #include @@ -34,6 +35,7 @@ void bind_time_series(py::module_& m, std::string const& name) { bind_class, EnablePickling::Required>(m, name.c_str()) .def(py::init(), py::arg("num_elements")) + .def(py::init>>(), py::arg("table")) .def("get_num_time_points", &mio::TimeSeries::get_num_time_points) .def("get_num_elements", &mio::TimeSeries::get_num_elements) .def("get_time", py::overload_cast(&mio::TimeSeries::get_time), py::arg("index")) From ef29c59031f5fbd6033e256ce604845e86f77b70 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:38:53 +0200 Subject: [PATCH 8/9] remove debug artifacts --- cpp/examples/abm_minimal.cpp | 5 +++-- cpp/models/abm/parameters.h | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index 662a3afc55..a525e3182f 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -165,8 +165,9 @@ int main() // The first column is Time. The other columns correspond to the number of people with a certain infection state at this Time: // Time = Time in days, S = Susceptible, E = Exposed, I_NS = InfectedNoSymptoms, I_Sy = InfectedSymptoms, I_Sev = InfectedSevere, // I_Crit = InfectedCritical, R = Recovered, D = Dead - // std::ofstream outfile("abm_minimal.txt"); - std::get<0>(historyTimeSeries.get_log()).print_table({"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4); + std::ofstream outfile("abm_minimal.txt"); + std::get<0>(historyTimeSeries.get_log()) + .print_table(outfile, {"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4); std::cout << "Results written to abm_minimal.txt" << std::endl; return 0; diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index daf256736a..1933d36722 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -41,7 +41,6 @@ #include "memilio/epidemiology/age_group.h" #include "memilio/epidemiology/damping.h" #include "memilio/epidemiology/contact_matrix.h" -#include "memilio/epidemiology/uncertain_matrix.h" #include #include From 43cb99672318814c31c2fc63219b5aab2679d627 Mon Sep 17 00:00:00 2001 From: Carlotta Gerstein <100771374+charlie0614@users.noreply.github.com> Date: Thu, 9 Apr 2026 16:00:38 +0200 Subject: [PATCH 9/9] remove personalrng from bindings --- pycode/examples/simulation/abm_minimal_example.py | 3 +-- .../memilio/simulation/bindings/models/abm.cpp | 15 ++++----------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/pycode/examples/simulation/abm_minimal_example.py b/pycode/examples/simulation/abm_minimal_example.py index 95c8a5a9db..c3a0ec61e8 100644 --- a/pycode/examples/simulation/abm_minimal_example.py +++ b/pycode/examples/simulation/abm_minimal_example.py @@ -130,10 +130,9 @@ for person in model.persons: infection_state = abm.InfectionState(rng.choice( len(infection_distribution), p=infection_distribution)) - prng = abm.PersonalRandomNumberGenerator(person) if infection_state != abm.InfectionState.Susceptible: - person.add_new_infection(prng, abm.VirusVariant.Wildtype, + person.add_new_infection(model, abm.VirusVariant.Wildtype, person.age, model.parameters, start_date, infection_state) # Assign locations diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp index 952f404343..99a3f027aa 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/abm.cpp @@ -164,17 +164,14 @@ PYBIND11_MODULE(_simulation_abm, m) .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine) .def_property_readonly("id", &mio::abm::Person::get_id) .def("add_new_infection", - [](mio::abm::Person& self, mio::abm::PersonalRandomNumberGenerator& rng, mio::abm::VirusVariant variant, - mio::AgeGroup age, mio::abm::Parameters& parameters, mio::abm::TimePoint start_date, + [](mio::abm::Person& self, mio::abm::Model& model, mio::abm::VirusVariant variant, mio::AgeGroup age, + mio::abm::Parameters& parameters, mio::abm::TimePoint start_date, mio::abm::InfectionState infection_state) { + mio::abm::PersonalRandomNumberGenerator person_rng(model.get_rng(), self); self.add_new_infection( - mio::abm::Infection(rng, variant, age, parameters, start_date, infection_state)); + mio::abm::Infection(person_rng, variant, age, parameters, start_date, infection_state)); }) .def("add_new_vaccination", &mio::abm::Person::add_new_vaccination, py::return_value_policy::reference_internal) - // .def("add_new_vaccination", - // [](mio::abm::Person& self, mio::abm::ProtectionType type, mio::abm::TimePoint start_date) { - // self.add_new_vaccination(type, start_date); - // }) .def("get_infection_state", &mio::abm::Person::get_infection_state) .def_property_readonly("vaccinations", py::overload_cast<>(&mio::abm::Person::get_vaccinations, py::const_)); @@ -304,10 +301,6 @@ PYBIND11_MODULE(_simulation_abm, m) m.def("add_household_group_to_model", &mio::abm::add_household_group_to_model); - pymio::bind_class( - m, "PersonalRandomNumberGenerator") - .def(py::init(), py::arg("person")); - pymio::bind_class(m, "Infection"); m.def("close_social_events", &mio::abm::close_social_events);