diff --git a/docs/source/python/m-generation.rst b/docs/source/python/m-generation.rst index bcabac83f1..4206da086e 100644 --- a/docs/source/python/m-generation.rst +++ b/docs/source/python/m-generation.rst @@ -1,6 +1,27 @@ MEmilio Generation =================== +.. note:: + + The ``memilio-generation`` package contains two independent tools: + + * **Model Generator:** generates a C++ compartmental model with Python bindings from a YAML/TOML configuration file. See :doc:`m-modelgenerator`. + * **Bindings Generator:** automatically generates Python bindings from existing C++ model source files using libclang. Described on this page. + +.. toctree:: + :maxdepth: 1 + :hidden: + + m-modelgenerator + +.. _bindings-generator: + +Bindings Generator +------------------ + +Overview +~~~~~~~~ + This package provides an automatic code generator for Python bindings of the MEmilio C++ library. It enables the automatic generation of a part of the :doc:`Python Bindings ` that is common across multiple models. For a particular example, see the SEIR model with its files `oseir.cpp `_ and `oseir.py `_. @@ -13,7 +34,7 @@ The following figure from Chapter 5 outlines the workflow of the generator. Blue Dependencies ----------- +~~~~~~~~~~~~ The package uses the `Clang C++ library `_ and the `LibClang Python library `_ to analyze the C++ code of a model. Both need to be installed and share the same version. @@ -33,7 +54,7 @@ For a successful build, the development libraries for Python need to be installe If you want to try a different version, set your `libclang` version under ``dependencies`` in the `pyproject.toml `_ and change the clang command in ``create_ast`` in `ast.py `_. Usage ------ +~~~~~ During the installation the package creates a compilation database (compile_commands.json) for the models of the `MEmilio C++ Library `_. @@ -59,7 +80,7 @@ To use the visualization run the command: Visualization -------------- +~~~~~~~~~~~~~ The package contains a `Visualization class `_ to display the generated AST. This allows you to visualize the abstract syntax tree (AST) of the C++ model in different ways: @@ -83,7 +104,7 @@ This means that any nodes beyond the specified depth (e.g., all nodes at level 3 Notice that the visualization as a PNG file should not print the whole AST, as it is not possible to display the whole AST in a single image. Development ------------ +~~~~~~~~~~~ When implementing new model features you can follow these steps: diff --git a/docs/source/python/m-modelgenerator.rst b/docs/source/python/m-modelgenerator.rst new file mode 100644 index 0000000000..0e7f655eb2 --- /dev/null +++ b/docs/source/python/m-modelgenerator.rst @@ -0,0 +1,432 @@ +Model Generator +=============== + +.. note:: + + Here, you start with a model specification and get C++ source files and Python bindings as output. + If you already have a C++ model and want to generate Python bindings for it, you can use the + :ref:`Bindings Generator ` (see :doc:`m-generation`). + +The model generator is part of the ``memilio-generation`` package and provides a high-level way to create new +compartmental ODE models for MEmilio from a simple configuration file. Instead of writing C++ code by hand, you +describe your model in a YAML or TOML file and the generator produces all required source files automatically. +C++ knowledge is not required to use the generator, but you can of course edit the generated C++ code afterwards +if you want to add custom features. Additionally, a Python example application is generated that you can run +immediately after generation is done. + +With the following description, we will generate a model that can later be stratified by demography and resolved spatially. The demographic stratification is one-dimensional with a naming of age groups. However, it can equally be used to stratify according to, e.g., sex/gender or income. + +Overview +~~~~~~~~ + +Given a configuration file, the generator produces the following files: + +.. list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Output file + - Description + * - ``cpp/models//infection_state.h`` + - C++ enum ``InfectionState`` with all compartments + * - ``cpp/models//parameters.h`` + - Parameter structs, ``ParametersBase``, and ``Parameters`` class with constraint checks + * - ``cpp/models//model.h`` + - ``Model`` class with the ``get_flows()`` implementation + * - ``cpp/models//model.cpp`` + - Minimal translation unit (includes ``model.h``) + * - ``cpp/models//CMakeLists.txt`` + - CMake library target for the new model + * - ``pycode/memilio-simulation/memilio/simulation/bindings/models/.cpp`` + - pybind11 module transferring the model to Python + * - ``pycode/examples/simulation/_simple.py`` + - Ready-to-run Python simulation example + +In the above description, `` is a short but representative name provided by the users; not containing any spaces; see below for an example. In addition to the above files, the two existing CMakeLists files +``cpp/CMakeLists.txt`` and ``pycode/memilio-simulation/CMakeLists.txt`` +are generated in place to register the new model. + +Configuration file format +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Both YAML and TOML are supported. For unexperienced users, we recommend YAML as YAML does not require quotes around string values, thus avoiding potential errors in parsing. + +.. note:: + + In TOML, all string values must be enclosed in quotes. + +The configuration file has four sections that are described below. For all names and namings (comments excluded), please do not use spaces. In general, avoid special characters (colons, question marks etc and in German ä, ö, ü; similarly for other languages) except hyphen and underscore. + +model +^^^^^ + +Metadata about the model. For a SEIR model it could look as follows. + +.. code-block:: yaml + + model: + name: SEIR # Human-readable name used in comments and doc-strings + namespace: oseir # In C++, we define a namespace to directly refer to model properties. We suggest to use `o` + a name, all in small letters. + prefix: ode_seir # Used for folder name and installation. We suggest to use the format `ode_` and a name all in small letters. + +infection_states +^^^^^^^^^^^^^^^^ + +A list of compartment names. At least two are required and all names must be unique. +If you check the generated results, an auxiliary ``Count`` compartment is added automatically at the end of the list for convenience of the computation. For the SEIR model, we have the following list. + +.. code-block:: yaml + + infection_states: + - Susceptible + - Exposed + - Infected + - Recovered + +parameters +^^^^^^^^^^ + +A list of model parameters. Each parameter entry will be encapsulated in a particular structure / class. + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Required + - Description + * - ``name`` + - yes + - Intuitive parameter structure name, e.g. ``TransmissionProbabilityOnContact`` + * - ``description`` + - yes + - Short but meaningful description used in the code documentation. + * - ``type`` + - yes + - ``probability`` (scalar in [0,1]), ``time`` (positive duration in days), or ``custom`` + * - ``default`` + - yes + - Default value serving as fallback value + * - ``per_age_group`` + - no: Only a single value can be set for the parameter + - ``true`` (default): For each age group, an individual parameter can be set. + * - ``bounds`` + - no: No bound checking or enforcing of the parameter will be done. + - ``[lower, upper]`` -- use ``null`` for an unbound parameter. + +Default value are passed to a function which only serves as a fallback solution if no value is set. If the users pays attention to always set the parameters, the default value can be ignored (i.e. set to a simple value like 0 or 1) + +.. dropdown:: :fa:`gears` Explanations for experienced C++ users + + Each parameter will obtain its own `struct`. Default value are passed to a ``get_default()`` function which + only serves as a fallback solution if no value is set. If stratification by age_groups is desired (`true` value) a + ``CustomIndexArray`` is used, otherwise the parameter will be represented by + MEmilio's custom-built ``UncertainValue`` which acts as a double value but also allows storing a parameter + distribution to sample values from. + +**Built-in types and their bounds:** + +Depending on the type and bounds provided by the user, MEmilio introduces a parameter constraint checking functionalism. +- ``probability``: constraint check enforces ``[0.0, 1.0]`` +- ``time``: constraint check uses the configured ``bounds``. If ``bounds`` are omitted, the default is ``[0.1, null]``. Values below ``0.1`` days are always raised to ``0.1`` days in the generated C++ constraint check to avoid unreasonably short compartment stays that drastically increase ODE solver run time. +- ``custom``: no automatic constraint check is generated + +.. note:: + + When at least one ``infection`` transition is present, a ``ContactPatterns`` parameter is + added to the model **automatically**, you do not need to declare it in the ``parameters`` + list. It stores the (age-stratified) contact frequencies / matrix (``UncertainContactMatrix``) and is used to compute the force of infection. + In the generated Python example and in your own scripts, set it up like this: + + .. code-block:: python + + model.parameters.ContactPatterns.cont_freq_mat[0].baseline = np.ones((num_groups, num_groups)) + model.parameters.ContactPatterns.cont_freq_mat[0].minimum = np.zeros((num_groups, num_groups)) + +The minimum contact pattern is a critical parameter as it defines a minimum contact frequency under which we cannot go below in the simulation, no matter the strictness of a nonpharmaceutical intervention. It should only be set if a good estimation is available. Otherwise, set it to zero. + +The parameters that need to be provided for the SEIR model are as follows. + +.. code-block:: yaml + + parameters: + - name: TransmissionProbabilityOnContact + description: probability of getting infected from a contact + type: probability + default: 1.0 + per_age_group: true + bounds: [0.0, 1.0] + + - name: TimeExposed + description: the latent time in day unit + type: time + default: 5.2 + per_age_group: true + bounds: [0.1, null] + + - name: TimeInfected + description: the infectious time in day unit + type: time + default: 6.0 + per_age_group: true + bounds: [0.1, null] + +transitions +^^^^^^^^^^^ + +In order to allow the on-the-fly computation of newly infected (or hospitalized for more complex models), provide a full list of transitions (or flows) between compartments. Each transition has the following fields: + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Required + - Description + * - ``from`` + - yes + - Source compartment (must be in ``infection_states``) + * - ``to`` + - yes + - Target compartment (must be in ``infection_states``, must differ from ``from``) + * - ``type`` + - yes + - ``infection``, ``linear``, or ``custom`` + * - ``parameter`` + - for ``infection`` and ``linear`` + - Name of the driving parameter (must be in ``parameters``) + * - ``infectious_state`` + - for ``infection`` + - Compartment whose population drives the force of infection (e.g. ``Infected``). You can pass a single state or a list of states (e.g. ``[InfectedNoSymptoms, InfectedSymptoms]``), in which case their populations are summed in the force of infection. + * - ``custom_formula`` + - no + - Optional hint placed in a ``TODO`` comment in the generated code + +**Transition types:** + +``infection`` + `infection` represents the force-of-infection flow. For age-resolved models, it generates a double loop over contact age groups using the ``ContactPatterns`` contact matrix. The ``ContactPatterns`` parameter is added to the + model automatically when at least one infection transition is present. + + .. math:: + + {S}'_i \leftarrow -\sum_j c_{ij} \cdot \phi \cdot \frac{I_j}{N_j} \cdot S_i + + where :math:`c_{ij}` is the contact rate between age groups *i* and *j*, + :math:`\phi` is the transmission probability, and :math:`N_j` is the total + population of age group *j*. + +``linear`` + The `linear` flow is a simple outflow proportional to the compartment size: + + .. math:: + + {X}'_i \leftarrow -\frac{1}{T_i} \cdot X_i + + where :math:`T_i` is the time parameter for age group *i*. + +``custom`` + For `custom`, a placeholder is inserted into ``get_flows()`` with a ``TODO`` comment. + If ``custom_formula`` is provided, it is shown as a hint next to the placeholder. + **The generated code will not compile until you fill in the expression.** + +For the SEIR model, we have the following transitions: + +.. code-block:: yaml + + transitions: + - from: Susceptible + to: Exposed + type: infection + parameter: TransmissionProbabilityOnContact + infectious_state: Infected + + - from: Exposed + to: Infected + type: linear + parameter: TimeExposed + + - from: Infected + to: Recovered + type: linear + parameter: TimeInfected + +Complete example: SEIR model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following YAML file fully specifies an SEIR model: + +.. code-block:: yaml + + model: + name: SEIR + namespace: oseir + prefix: ode_seir + + infection_states: + - Susceptible + - Exposed + - Infected + - Recovered + + parameters: + - name: TransmissionProbabilityOnContact + description: probability of getting infected from a contact + type: probability + default: 1.0 + per_age_group: true + bounds: [0.0, 1.0] + + - name: TimeExposed + description: the latent time in day unit + type: time + default: 5.2 + per_age_group: true + bounds: [0.1, null] + + - name: TimeInfected + description: the infectious time in day unit + type: time + default: 6.0 + per_age_group: true + bounds: [0.1, null] + + transitions: + - from: Susceptible + to: Exposed + type: infection + parameter: TransmissionProbabilityOnContact + infectious_state: Infected + + - from: Exposed + to: Infected + type: linear + parameter: TimeExposed + + - from: Infected + to: Recovered + type: linear + parameter: TimeInfected + +More example configurations (including an SEIRD model with a ``custom`` transition and a TOML +version of the SEIR model) can be found in +`pycode/examples/modelgenerator/ `_. + +Usage +~~~~~ + +Installation +^^^^^^^^^^^^ + +Install the ``memilio-generation`` package from the repository root: + +.. code-block:: console + + pip install -e pycode/memilio-generation + +The installation registers the ``memilio-modelgenerator`` command and makes the +``memilio.modelgenerator`` Python module available. + +Command-line interface +^^^^^^^^^^^^^^^^^^^^^^ + +The generator is installed as the command ``memilio-modelgenerator``: + +.. code-block:: console + + # Write all files into the MEmilio repository root + memilio-modelgenerator path/to/seir.yaml --output-dir /path/to/memilio + + # Preview the generated files without writing them + memilio-modelgenerator path/to/seir.yaml --preview + + # TOML input works the same way + memilio-modelgenerator path/to/seir.toml --output-dir /path/to/memilio + + # Overwrite an existing model directory (see warning below) + memilio-modelgenerator path/to/seir.yaml --output-dir /path/to/memilio --force + +Python API +^^^^^^^^^^ + +.. code-block:: python + + from memilio.modelgenerator import Generator + + # Load from YAML + gen = Generator.from_yaml("seir.yaml") + + # Load from TOML + gen = Generator.from_toml("seir.toml") + + # Load from a dict (useful in scripts or tests) + gen = Generator.from_dict(raw_dict) + + # Render all files to a dict {relative_path: content} + files = gen.render() + + # Write all files and patch existing CMakeLists + gen.write("/path/to/memilio") + + # Overwrite an existing model directory (see warning below) + gen.write("/path/to/memilio", overwrite=True) + +.. warning:: + + The generator refuses to write into a model directory that already exists + (``cpp/models//``) unless ``overwrite=True`` (Python API) or ``--force`` + (CLI) is passed explicitly. + This guard is intentional: ``prefix`` and ``namespace`` must be unique across the + whole MEmilio repository. Using the same values as an existing model (e.g. + ``prefix: ode_seir``) would replace a handwritten C++ source file of + that model with generated ones. + +After generation +^^^^^^^^^^^^^^^^ + +1. **Fill in custom transitions** (if any): open the generated ``model.h`` and replace the + ``/* YOUR EXPRESSION HERE */`` placeholder with the actual expression before compiling. + +2. **Compile the model** by building the MEmilio C++ library as usual (CMake). + The patched ``cpp/CMakeLists.txt`` picks up the new model directory automatically. + See :doc:`/cpp/installation` for details on configuring and building with CMake. + +3. **Install the Python bindings** by reinstalling ``memilio-simulation``: + + .. code-block:: console + + pip install -e pycode/memilio-simulation + +4. **Run the generated example**: + + .. code-block:: console + + python pycode/examples/simulation/_simple.py + +Validation +~~~~~~~~~~ + +The generator validates the configuration before any code is produced. +All errors are collected and reported together. + +Common validation errors: + +* Missing or empty ``model``, ``infection_states``, ``parameters``, or ``transitions`` section +* Fewer than two infection states, or duplicate state names +* Parameter ``type`` is not one of ``probability``, ``time``, ``custom`` +* ``parameter`` or ``infectious_state`` / ``infectious_states`` in a transition references an unknown name +* A transition has the same ``from`` and ``to`` state (self-loop) + +Development and extension +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Adding a new transition type or template feature: + +1. Add the new type constant to ``TransitionType`` in + `schema.py `_. +2. Add validation logic to + `validator.py `_. +3. Update the relevant Jinja2 templates under + `templates/ `_. +4. Update the tests in + `tests/test_modelgenerator.py `_. diff --git a/pycode/examples/modelgenerator/seir.toml b/pycode/examples/modelgenerator/seir.toml new file mode 100644 index 0000000000..ec7eb8afc1 --- /dev/null +++ b/pycode/examples/modelgenerator/seir.toml @@ -0,0 +1,49 @@ +infection_states = ["Susceptible", "Exposed", "Infected", "Recovered"] + +[model] +name = "SEIR" +namespace = "oseir" +prefix = "ode_seir" + +[[parameters]] +name = "TransmissionProbabilityOnContact" +description = "probability of getting infected from a contact" +type = "probability" +default = 1 +per_age_group = true +bounds = [0, 1] + +[[parameters]] +name = "TimeExposed" +description = "the latent time in day unit" +type = "time" +default = 5.2 +per_age_group = true +bounds = [0.1, 1e100] + +[[parameters]] +name = "TimeInfected" +description = "the infectious time in day unit" +type = "time" +default = 6 +per_age_group = true +bounds = [0.1, 1e100] + +[[transitions]] +from = "Susceptible" +to = "Exposed" +type = "infection" +parameter = "TransmissionProbabilityOnContact" +infectious_state = "Infected" + +[[transitions]] +from = "Exposed" +to = "Infected" +type = "linear" +parameter = "TimeExposed" + +[[transitions]] +from = "Infected" +to = "Recovered" +type = "linear" +parameter = "TimeInfected" diff --git a/pycode/examples/modelgenerator/seir.yaml b/pycode/examples/modelgenerator/seir.yaml new file mode 100644 index 0000000000..414cecf070 --- /dev/null +++ b/pycode/examples/modelgenerator/seir.yaml @@ -0,0 +1,49 @@ +model: + name: SEIR + namespace: oseir + prefix: ode_seir + +infection_states: + - Susceptible + - Exposed + - Infected + - Recovered + +parameters: + - name: TransmissionProbabilityOnContact + description: probability of getting infected from a contact + type: probability + default: 1.0 + per_age_group: true + bounds: [0.0, 1.0] + + - name: TimeExposed + description: the latent time in day unit + type: time + default: 5.2 + per_age_group: true + bounds: [0.1, null] + + - name: TimeInfected + description: the infectious time in day unit + type: time + default: 6.0 + per_age_group: true + bounds: [0.1, null] + +transitions: + - from: Susceptible + to: Exposed + type: infection + parameter: TransmissionProbabilityOnContact + infectious_state: Infected + + - from: Exposed + to: Infected + type: linear + parameter: TimeExposed + + - from: Infected + to: Recovered + type: linear + parameter: TimeInfected diff --git a/pycode/examples/modelgenerator/seird.yaml b/pycode/examples/modelgenerator/seird.yaml new file mode 100644 index 0000000000..936c9723e9 --- /dev/null +++ b/pycode/examples/modelgenerator/seird.yaml @@ -0,0 +1,62 @@ +model: + name: SEIRD + namespace: oseird + prefix: ode_seird + +infection_states: + - Susceptible + - Exposed + - Infected + - Recovered + - Dead + +parameters: + - name: TransmissionProbabilityOnContact + description: probability of getting infected from a contact + type: probability + default: 1.0 + per_age_group: true + bounds: [0.0, 1.0] + + - name: TimeExposed + description: the latent time in day unit + type: time + default: 5.2 + per_age_group: true + bounds: [0.1, null] + + - name: TimeInfected + description: the infectious time in day unit + type: time + default: 6.0 + per_age_group: true + bounds: [0.1, null] + + - name: DeathRate + description: daily probability of dying while infected + type: probability + default: 0.01 + per_age_group: true + bounds: [0.0, 1.0] + +transitions: + - from: Susceptible + to: Exposed + type: infection + parameter: TransmissionProbabilityOnContact + infectious_state: Infected + + - from: Exposed + to: Infected + type: linear + parameter: TimeExposed + + - from: Infected + to: Recovered + type: linear + parameter: TimeInfected + + - from: Infected + to: Dead + type: custom + custom_formula: "DeathRate[i] * y[idx_Infected_i]" diff --git a/pycode/memilio-generation/README.md b/pycode/memilio-generation/README.md index 8f60e1dd4e..9cd4d30dc2 100644 --- a/pycode/memilio-generation/README.md +++ b/pycode/memilio-generation/README.md @@ -1,6 +1,12 @@ -# MEmilio Automatic Code Generation of Python Bindings +# MEmilio Generation -This package contains Python bindings generating code for the MEmilio C++ library. +This package contains two tools: the **Bindings Generator** (automatically generates Python bindings from existing C++ models) and the **Model Generator** (generates a complete C++ compartmental model and Python bindings from a YAML/TOML specification). + +For full documentation see the [readthedocs pages](https://memilio.readthedocs.io/en/latest/python/m-generation.html): [Bindings Generator](https://memilio.readthedocs.io/en/latest/python/m-generation.html) | [Model Generator](https://memilio.readthedocs.io/en/latest/python/m-modelgenerator.html). + +## Bindings Generator + +This part contains Python bindings generating code for the MEmilio C++ library. It enables the automatic generation of a part of the [Python Bindings](../memilio-simulation/README.md) that is common across multiple models. For a particular example, see the SEIR model with its files `oseir.cpp` and `oseir.py`. This generating software was developed as a part of the Bachelor thesis [Automatische Codegenerierung für nutzerfreundliche mathematisch-epidemiologische Modelle](https://elib.dlr.de/190367/). The following figure from Chapter 5 outlines the workflow of the generator. Blue boxes represent parts of the code generator and orange ones the input and output. Rectangular boxes contain classes with logic, the rest represent data. diff --git a/pycode/memilio-generation/memilio/modelgenerator/__init__.py b/pycode/memilio-generation/memilio/modelgenerator/__init__.py new file mode 100644 index 0000000000..72187c0329 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/__init__.py @@ -0,0 +1,36 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +MEmilio Model Generator. + +Automatically generates C++ source files and pybind11 bindings for MEmilio ODE +compartment models from a YAML configuration file. + +Example:: + + from memilio.modelgenerator import Generator + gen = Generator.from_yaml("examples/seir.yaml") + gen.write(output_dir="/path/to/memilio") +""" + +from .schema import ModelConfig, ParameterConfig, TransitionConfig +from .generator import Generator +from .validator import Validator diff --git a/pycode/memilio-generation/memilio/modelgenerator/cli.py b/pycode/memilio-generation/memilio/modelgenerator/cli.py new file mode 100644 index 0000000000..24e72e8155 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/cli.py @@ -0,0 +1,135 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +Command-line interface for the model generator. + +Usage:: + + memilio-modelgenerator path/to/model.yaml [--output-dir DIR] [--preview] +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +from .generator import Generator +from .validator import ValidationError + + +def main(argv=None) -> int: + """ + Run the model-generator command-line interface. + + :param argv: Optional argument vector. If ``None``, ``sys.argv`` is used. + :type argv: list[str] | None + :returns: Process exit code. + :rtype: int + """ + parser = argparse.ArgumentParser( + prog="memilio-modelgenerator", + description="Generate MEmilio C++ model files and pybind11 bindings from a YAML config.", + ) + parser.add_argument( + "config", + metavar="CONFIG", + help="Path to the YAML model configuration file.", + ) + parser.add_argument( + "--output-dir", + metavar="DIR", + default=None, + help=( + "Root directory of the MEmilio repository where files are written. " + "Defaults to the directory two levels above this package " + "(i.e. the repository root when installed in editable mode)." + ), + ) + parser.add_argument( + "--preview", + action="store_true", + help="Print all generated file contents instead of writing them to disk.", + ) + parser.add_argument( + "--force", + action="store_true", + help=( + "Overwrite existing model files. By default the generator refuses to " + "write into an already existing model directory to prevent accidentally " + "overwriting existing, handwritten C++ code." + ), + ) + + args = parser.parse_args(argv) + + try: + if args.config.endswith(".toml"): + gen = Generator.from_toml(args.config) + else: + gen = Generator.from_yaml(args.config) + except FileNotFoundError: + print(f"ERROR: config file not found: {args.config}", file=sys.stderr) + return 1 + except ValidationError as exc: + print(str(exc), file=sys.stderr) + return 1 + + if args.preview: + output_dir = Path(args.output_dir) if args.output_dir else Path( + __file__).resolve().parents[4] + separator = "=" * 72 + for rel_path, content in gen.render().items(): + print(f"\n{separator}") + print(f" NEW FILE: {rel_path}") + print(separator) + print(content) + for rel_path, content in gen.render_patches(output_dir).items(): + if content is None: + print(f"\n{separator}") + print(f" PATCH (already present – no change): {rel_path}") + print(separator) + else: + print(f"\n{separator}") + print(f" PATCH: {rel_path}") + print(separator) + print(content) + return 0 + + # Output directory + if args.output_dir is not None: + output_dir = Path(args.output_dir) + else: + # Go up from this file to the repo root + output_dir = Path(__file__).resolve().parents[4] + + print(f"Writing model files to: {output_dir}") + try: + gen.write(output_dir, overwrite=args.force) + except FileExistsError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + print("Done.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pycode/memilio-generation/memilio/modelgenerator/generator.py b/pycode/memilio-generation/memilio/modelgenerator/generator.py new file mode 100644 index 0000000000..afb24d3a70 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/generator.py @@ -0,0 +1,337 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +Core generator: parses a YAML config, builds the internal `ModelConfig` +representation, and renders all Jinja2 templates into strings. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Dict, Optional +import re + +import yaml + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + +if sys.version_info >= (3, 9): + import importlib.resources as importlib_resources +else: + import importlib_resources + +from jinja2 import Environment, PackageLoader, StrictUndefined + +from .schema import ( + ModelConfig, + ModelMeta, + ParameterConfig, + ParameterType, + TransitionConfig, + TransitionType, +) +from .validator import Validator + + +class Generator: + """Parse model configurations and render model source files.""" + + def __init__(self, config: ModelConfig): + """ + Initialize a generator from a validated model configuration. + + :param config: Fully validated model configuration. + :type config: ModelConfig + """ + self._config = config + self._env = Environment( + loader=PackageLoader("memilio.modelgenerator", "templates"), + undefined=StrictUndefined, + keep_trailing_newline=True, + trim_blocks=True, + lstrip_blocks=True, + ) + + @classmethod + def from_yaml(cls, yaml_path: str | Path) -> Generator: + """ + Build a generator from a YAML configuration file. + + :param yaml_path: Path to a ``.yaml`` configuration file. + :type yaml_path: str | Path + :returns: Generator initialized from the parsed YAML configuration. + :rtype: Generator + """ + with open(yaml_path, encoding="utf-8") as fh: + raw = yaml.safe_load(fh) + + Validator.validate(raw) + config = cls._parse(raw) + return cls(config) + + @classmethod + def from_toml(cls, toml_path: str | Path) -> Generator: + """ + Build a generator from a TOML configuration file. + + :param toml_path: Path to a ``.toml`` configuration file. + :type toml_path: str | Path + :returns: Generator initialized from the parsed TOML configuration. + :rtype: Generator + """ + with open(toml_path, "rb") as fh: + raw = tomllib.load(fh) + + Validator.validate(raw) + config = cls._parse(raw) + return cls(config) + + @classmethod + def from_dict(cls, raw: dict) -> Generator: + """ + Build a generator from an already loaded dictionary. + + :param raw: Dictionary as returned by ``yaml.safe_load``. + :type raw: dict + :returns: Generator initialized from ``raw``. + :rtype: Generator + """ + Validator.validate(raw) + config = cls._parse(raw) + return cls(config) + + def render(self) -> dict[str, str]: + """ + Render all generated files. + + Keys in the returned mapping are paths relative to the MEmilio + repository root. Use :meth:`render_patches` for in-place edits of + existing CMake files. + + :returns: Mapping from relative output path to rendered file content. + :rtype: dict[str, str] + """ + cfg = self._config + prefix = cfg.meta.prefix + + return { + f"cpp/models/{prefix}/infection_state.h": self._render("infection_state_h.jinja2"), + f"cpp/models/{prefix}/parameters.h": self._render("parameters_h.jinja2"), + f"cpp/models/{prefix}/model.h": self._render("model_h.jinja2"), + f"cpp/models/{prefix}/model.cpp": self._render("model_cpp.jinja2"), + f"cpp/models/{prefix}/CMakeLists.txt": self._render("CMakeLists_model_txt.jinja2"), + ( + f"pycode/memilio-simulation/memilio/simulation/bindings/models/{prefix}.cpp" + ): self._render("pybindings_cpp.jinja2"), + f"pycode/examples/simulation/{prefix}_simple.py": self._render("example_py.jinja2"), + ( + f"pycode/memilio-simulation/memilio/simulation/{cfg.meta.namespace}.py" + ): self._render("simulation_py.jinja2"), + } + + _CPP_CMAKE = "cpp/CMakeLists.txt" + _SIM_CMAKE = "pycode/memilio-simulation/CMakeLists.txt" + _SIM_INIT = "pycode/memilio-simulation/memilio/simulation/__init__.py" + + def render_patches(self, output_dir: Path) -> dict[str, str | None]: + """ + Compute patch content for existing project files. + + ``None`` in the result means no change is needed because the entry + already exists. + + :param output_dir: Repository root directory. + :type output_dir: Path + :returns: Mapping from relative path to patched content or ``None``. + :rtype: dict[str, str | None] + """ + prefix = self._config.meta.prefix + namespace = self._config.meta.namespace + results: dict[str, str | None] = {} + + # cpp/CMakeLists.txt + cpp_cmake = output_dir / self._CPP_CMAKE + if cpp_cmake.exists(): + text = cpp_cmake.read_text(encoding="utf-8") + entry = f" add_subdirectory(models/{prefix})" + if entry not in text: + # Insert after the last add_subdirectory(models/…) line + pattern = r"( add_subdirectory\(models/[^)]+\))(?!.*add_subdirectory\(models/)" + m = re.search(pattern, text, re.DOTALL) + if m: + insert_at = m.end() + text = text[:insert_at] + "\n" + entry + text[insert_at:] + results[self._CPP_CMAKE] = text + else: + results[self._CPP_CMAKE] = None # already present + + # pycode/memilio-simulation/CMakeLists.txt + sim_cmake = output_dir / self._SIM_CMAKE + if sim_cmake.exists(): + text = sim_cmake.read_text(encoding="utf-8") + module_name = f"_simulation_{namespace}" + block = ( + f"add_pymio_module({module_name}\n" + f" LINKED_LIBRARIES memilio {prefix}\n" + f" SOURCES memilio/simulation/bindings/models/{prefix}.cpp\n" + f")") + if f"add_pymio_module({module_name}\n" not in text: + # Insert before the "# install all shared" comment + marker = "# install all shared memilio libraries" + text = text.replace(marker, block + "\n\n" + marker) + results[self._SIM_CMAKE] = text + else: + results[self._SIM_CMAKE] = None # already present + + # pycode/memilio-simulation/memilio/simulation/__init__.py + sim_init = output_dir / self._SIM_INIT + if sim_init.exists(): + text = sim_init.read_text(encoding="utf-8") + lazy_entry = ( + f" elif attr == \"{namespace}\":\n" + f" import memilio.simulation.{namespace} as {namespace}\n" + f" return {namespace}\n" + ) + if f'attr == "{namespace}"' not in text: + text = text.replace( + " raise AttributeError", + lazy_entry + " raise AttributeError" + ) + results[self._SIM_INIT] = text + else: + results[self._SIM_INIT] = None # already present + + return results + + def write(self, output_dir: str | Path, overwrite: bool = False) -> None: + """ + Write rendered files and apply required project-file patches. + + Directories are created as needed. + + :param output_dir: Root of the target MEmilio repository. + :type output_dir: str | Path + :param overwrite: Allow overwriting an existing model directory. + :type overwrite: bool + :raises FileExistsError: If model directory exists and ``overwrite`` is + ``False``. + """ + output_dir = Path(output_dir) + prefix = self._config.meta.prefix + model_dir = output_dir / "cpp" / "models" / prefix + if model_dir.exists() and not overwrite: + raise FileExistsError( + f"Model directory already exists: {model_dir}\n" + f"Pass overwrite=True (or --force on the CLI) to overwrite it." + ) + + for rel_path, content in self.render().items(): + target = output_dir / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + print(f" wrote {rel_path}") + + for rel_path, content in self.render_patches(output_dir).items(): + if content is None: + print(f" skip {rel_path} (entry already present)") + else: + (output_dir / rel_path).write_text(content, encoding="utf-8") + print(f" patched {rel_path}") + + def _render(self, template_name: str) -> str: + tmpl = self._env.get_template(template_name) + return tmpl.render(cfg=self._config) + + @staticmethod + def _parse(raw: dict) -> ModelConfig: + meta = ModelMeta( + name=raw["model"]["name"], + namespace=raw["model"]["namespace"], + prefix=raw["model"]["prefix"], + ) + + states: list[str] = raw["infection_states"] + + parameters = [] + for p in raw["parameters"]: + bounds_raw = p.get("bounds") + if bounds_raw is not None: + lower = bounds_raw[0] + upper = bounds_raw[1] + if p["type"] == ParameterType.TIME: + lower = 1e-1 if lower is None else max(float(lower), 1e-1) + bounds = (lower, upper) + else: + if p["type"] == ParameterType.PROBABILITY: + bounds = (0.0, 1.0) + elif p["type"] == ParameterType.TIME: + bounds = (1e-1, None) + else: + bounds = (None, None) + + parameters.append( + ParameterConfig( + name=p["name"], + description=p.get("description", ""), + type=p["type"], + default=float(p["default"]), + per_age_group=bool(p.get("per_age_group", True)), + bounds=bounds, + ) + ) + + transitions = [] + for t in raw["transitions"]: + raw_infectious_states = t.get("infectious_states") + raw_infectious_state = t.get("infectious_state") + if isinstance(raw_infectious_states, list): + infectious_states = list(raw_infectious_states) + elif raw_infectious_states is not None: + infectious_states = [raw_infectious_states] + elif isinstance(raw_infectious_state, list): + infectious_states = list(raw_infectious_state) + elif raw_infectious_state is not None: + infectious_states = [raw_infectious_state] + else: + infectious_states = [] + + transitions.append( + TransitionConfig( + from_state=t["from"], + to_state=t["to"], + type=t["type"], + parameter=t.get("parameter"), + infectious_state=infectious_states[0] if infectious_states else None, + infectious_states=infectious_states, + custom_formula=t.get("custom_formula"), + ) + ) + + return ModelConfig( + meta=meta, + infection_states=states, + parameters=parameters, + transitions=transitions, + ) diff --git a/pycode/memilio-generation/memilio/modelgenerator/schema.py b/pycode/memilio-generation/memilio/modelgenerator/schema.py new file mode 100644 index 0000000000..aa5954cf78 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/schema.py @@ -0,0 +1,171 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +Dataclass definitions that represent a parsed model configuration. + +These are the internal representations produced by parsing a YAML file. +The `Generator` consumes these objects and passes them to the +Jinja2 templates. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + + +# Transition types + +class TransitionType: + """Symbolic constants for the supported flow types.""" + INFECTION = "infection" + """Force-of-infection flow using contact matrix and S*I/N.""" + LINEAR = "linear" + """Simple outflow: (1 / parameter) * source_compartment.""" + CUSTOM = "custom" + """Placeholder. User must supply the expression manually.""" + + ALL = (INFECTION, LINEAR, CUSTOM) + + +class ParameterType: + """Symbolic constants for built-in parameter storage types.""" + PROBABILITY = "probability" + """Scalar in [0, 1] per age group; stored as UncertainValue.""" + TIME = "time" + """Positive duration in days per age group; stored as UncertainValue.""" + CUSTOM = "custom" + """User-defined; no automatic constraint check is generated.""" + + ALL = (PROBABILITY, TIME, CUSTOM) + + +@dataclass +class ModelMeta: + + name: str + """Human-readable model name, e.g. ``"SEIR"``.""" + + namespace: str + """Inner C++ namespace, e.g. ``"oseir"`` → ``mio::oseir``.""" + + prefix: str + """Folder and CMake target prefix, e.g. ``"ode_seir"``.""" + + @property + def guard_prefix(self) -> str: + """Upper-case version of ``prefix`` used in include guards.""" + return self.prefix.upper() + + +@dataclass +class ParameterConfig: + """Configuration for a single model parameter.""" + + name: str + """C++ struct name, e.g. ``"TransmissionProbabilityOnContact"``.""" + + description: str + """Short description used in the Doxygen comment.""" + + type: str + """One of `ParameterType.ALL`.""" + + default: float + """Default value passed to ``get_default``.""" + + per_age_group: bool = True + """If ``True`` the storage type is ``CustomIndexArray, AgeGroup>``.""" + + bounds: tuple[float | None, float | None] = field( + default_factory=lambda: (None, None)) + """(lower, upper) bounds used in the constraint checks. ``None`` means unchecked.""" + + +@dataclass +class TransitionConfig: + """Configuration for a single compartment flow.""" + + from_state: str + """Source compartment name.""" + + to_state: str + """Target compartment name.""" + + type: str + """One of `TransitionType.ALL`.""" + + parameter: str | None = None + """Name of the `ParameterConfig` that drives this flow.""" + + infectious_state: str | None = None + """ + For ``type == "infection"``: the compartment whose population drives + infection (typically ``"Infected"``). Kept as a compatibility alias + for the first entry of ``infectious_states``. + """ + + infectious_states: list[str] = field(default_factory=list) + """ + For ``type == "infection"``: list of compartments whose populations + are summed to drive infection. + """ + + custom_formula: str | None = None + """ + For ``type == "custom"``: an optional hint that is placed in a + ``TODO`` comment next to the placeholder. + """ + + +@dataclass +class ModelConfig: + """Complete parsed model configuration.""" + + meta: ModelMeta + infection_states: list[str] + parameters: list[ParameterConfig] + transitions: list[TransitionConfig] + + @property + def has_infection_transition(self) -> bool: + """``True`` if at least one transition uses the force-of-infection.""" + return any(t.type == TransitionType.INFECTION for t in self.transitions) + + @property + def all_parameters(self) -> list[ParameterConfig]: + """ + Full parameter list including the implicitly added + ``ContactPatterns`` when any infection transition is present. + """ + if not self.has_infection_transition: + return self.parameters + # ContactPatterns is added at the end; the generator inserts it + # directly into the template so we only expose the user-defined ones + # here. The template accesses has_infection_transition separately. + return self.parameters + + def parameters_for_constraint_check(self) -> list[ParameterConfig]: + """Return parameters that have explicit bound constraints.""" + return [ + p for p in self.parameters + if p.type in (ParameterType.PROBABILITY, ParameterType.TIME) + ] diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/CMakeLists_model_txt.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/CMakeLists_model_txt.jinja2 new file mode 100644 index 0000000000..73d9de0192 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/CMakeLists_model_txt.jinja2 @@ -0,0 +1,12 @@ +add_library({{ cfg.meta.prefix }} + infection_state.h + model.h + model.cpp + parameters.h +) +target_link_libraries({{ cfg.meta.prefix }} PUBLIC memilio) +target_include_directories({{ cfg.meta.prefix }} PUBLIC + $ + $ +) +target_compile_options({{ cfg.meta.prefix }} PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS}) diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/example_py.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/example_py.jinja2 new file mode 100644 index 0000000000..45a947ace9 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/example_py.jinja2 @@ -0,0 +1,95 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: generated by memilio-modelgenerator +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +Example simulation of the {{ cfg.meta.name }} model. + +Generated by memilio-modelgenerator, edit as needed. +""" + +import numpy as np + +from memilio.simulation import AgeGroup +{% if cfg.has_infection_transition %} +from memilio.simulation import Damping +{% endif %} +from memilio.simulation.{{ cfg.meta.namespace }} import InfectionState as State +from memilio.simulation.{{ cfg.meta.namespace }} import ( + Model, simulate, interpolate_simulation_result) + +{% set first_state = cfg.infection_states[0] %} +{% set second_state = cfg.infection_states[1] %} + + +def run_simulation(t0=0.0, tmax=10.0, dt=0.1): + """Run a {{ cfg.meta.name }} simulation and print a result table. + + :param t0: Start time in days. + :type t0: float + :param tmax: End time in days. + :type tmax: float + :param dt: Integration step size in days. + :type dt: float + """ + num_groups = 1 + model = Model(num_groups) + A0 = AgeGroup(0) + + total_population = 83_000 + + # Parameters – default values from model specification +{% for p in cfg.parameters %} + model.parameters.{{ p.name }}[A0] = {{ p.default }} +{% endfor %} +{% if cfg.has_infection_transition %} + + # Contact patterns (one-group baseline, no dampings) + model.parameters.ContactPatterns.cont_freq_mat[0].baseline = np.ones( + (num_groups, num_groups)) + model.parameters.ContactPatterns.cont_freq_mat[0].minimum = np.zeros( + (num_groups, num_groups)) +{% endif %} + + # Initial conditions with 100 people in {{ second_state }} + model.populations[A0, State.{{ second_state }}] = 100 + model.populations.set_difference_from_total( + (A0, State.{{ first_state }}), total_population) + + model.check_constraints() + + # Simulate + result = simulate(t0, tmax, dt, model) + result = interpolate_simulation_result(result) + + # Print table + states = [{% for s in cfg.infection_states %}"{{ s }}"{% if not loop.last %}, {% endif %}{% endfor %}] + col_w = 14 + header = f"{'t':>6} " + " ".join(f"{s:>{col_w}}" for s in states) + print(header) + print("-" * len(header)) + for i in range(result.get_num_time_points()): + t = result.get_time(i) + vals = result.get_value(i) + row = f"{t:>6.1f} " + " ".join(f"{v:>{col_w}.2f}" for v in vals) + print(row) + + +if __name__ == "__main__": + run_simulation() diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/infection_state_h.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/infection_state_h.jinja2 new file mode 100644 index 0000000000..b4f5f44038 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/infection_state_h.jinja2 @@ -0,0 +1,43 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: generated by memilio-modelgenerator +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef {{ cfg.meta.guard_prefix }}_INFECTIONSTATE_H +#define {{ cfg.meta.guard_prefix }}_INFECTIONSTATE_H + +namespace mio +{ +namespace {{ cfg.meta.namespace }} +{ + +/** + * @brief The InfectionState enum describes the possible categories + * for the infectious state of persons in the {{ cfg.meta.name }} model. + */ +enum class InfectionState +{ +{% for state in cfg.infection_states %} + {{ state }}, +{% endfor %} + Count +}; + +} // namespace {{ cfg.meta.namespace }} +} // namespace mio + +#endif // {{ cfg.meta.guard_prefix }}_INFECTIONSTATE_H diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/model_cpp.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/model_cpp.jinja2 new file mode 100644 index 0000000000..aeb7edbbc7 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/model_cpp.jinja2 @@ -0,0 +1,20 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: generated by memilio-modelgenerator +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "{{ cfg.meta.prefix }}/model.h" diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 new file mode 100644 index 0000000000..52cd5fcbac --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/model_h.jinja2 @@ -0,0 +1,189 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: generated by memilio-modelgenerator +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef {{ cfg.meta.guard_prefix }}_MODEL_H +#define {{ cfg.meta.guard_prefix }}_MODEL_H + +#include "memilio/compartments/flow_model.h" +#include "memilio/config.h" +#include "memilio/epidemiology/age_group.h" +#include "memilio/epidemiology/populations.h" +#include "memilio/utils/time_series.h" +#include "{{ cfg.meta.prefix }}/infection_state.h" +#include "{{ cfg.meta.prefix }}/parameters.h" + +GCC_CLANG_DIAGNOSTIC(push) +GCC_CLANG_DIAGNOSTIC(ignored "-Wshadow") +#include +GCC_CLANG_DIAGNOSTIC(pop) + +namespace mio +{ +namespace {{ cfg.meta.namespace }} +{ + +/******************** + * define the model * + ********************/ + +// clang-format off +using Flows = TypeList< +{% for t in cfg.transitions %} + Flow{% if not loop.last %},{% endif %} + +{% endfor %} +>; +// clang-format on + +template +class Model + : public FlowModel, Parameters, Flows> +{ + using Base = + FlowModel, Parameters, Flows>; + +public: + using typename Base::ParameterSet; + using typename Base::Populations; + + Model(const Populations& pop, const ParameterSet& params) + : Base(pop, params) + { + } + + Model(int num_agegroups) + : Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), + ParameterSet(AgeGroup(num_agegroups))) + { + } + + void get_flows(Eigen::Ref> pop, Eigen::Ref> y, FP t, + Eigen::Ref> flows) const override + { + const Index age_groups = reduce_index>(this->populations.size()); + const auto& params = this->parameters; + + for (auto i : make_index_range(age_groups)) { + // Flat indices for age group i +{% for state in cfg.infection_states %} + const size_t idx_{{ state }}_i = + this->populations.get_flat_index({i, InfectionState::{{ state }}}); +{% endfor %} + +{% if cfg.has_infection_transition %} + // ---------------------------------------------------------------- + // Infection transitions – double loop over contact age groups + // ---------------------------------------------------------------- +{% for t in cfg.transitions if t.type == 'infection' %} + for (auto j : make_index_range(age_groups)) { + // Flat indices for age group j +{% for state in cfg.infection_states %} + const size_t idx_{{ state }}_j = + this->populations.get_flat_index({j, InfectionState::{{ state }}}); +{% endfor %} + + // Total population of age group j + const FP Nj = +{% for state in cfg.infection_states %} + pop[idx_{{ state }}_j]{% if not loop.last %} +{% endif %} + +{% endfor %} + ; + const FP divNj = (Nj < Limits::zero_tolerance()) ? FP(0.0) : FP(1.0 / Nj); + + const FP coeff_{{ t.from_state }}_to_{{ t.to_state }} = + params.template get>() + .get_cont_freq_mat() + .get_matrix_at(SimulationTime(t))(i.get(), j.get()) * + params.template get<{{ t.parameter }}>()[i] * divNj; + + flows[Base::template get_flat_flow_index< + InfectionState::{{ t.from_state }}, + InfectionState::{{ t.to_state }}>(i)] += + coeff_{{ t.from_state }}_to_{{ t.to_state }} * + y[idx_{{ t.from_state }}_i] * ( +{% for inf_state in t.infectious_states %} + pop[idx_{{ inf_state }}_j]{% if not loop.last %} +{% endif %} +{% endfor %} + ); + } +{% endfor %} +{% endif %} + + // ---------------------------------------------------------------- + // Linear outflow transitions + // ---------------------------------------------------------------- +{% for t in cfg.transitions if t.type == 'linear' %} + flows[Base::template get_flat_flow_index< + InfectionState::{{ t.from_state }}, + InfectionState::{{ t.to_state }}>(i)] = + (FP(1.0) / params.template get<{{ t.parameter }}>()[i]) * + y[idx_{{ t.from_state }}_i]; +{% endfor %} + +{% for t in cfg.transitions if t.type == 'custom' %} + // ---------------------------------------------------------------- + // TODO: Custom transition {{ t.from_state }} -> {{ t.to_state }} +{% if t.custom_formula %} + // Hint: {{ t.custom_formula }} +{% endif %} + // ---------------------------------------------------------------- + flows[Base::template get_flat_flow_index< + InfectionState::{{ t.from_state }}, + InfectionState::{{ t.to_state }}>(i)] = + /* YOUR EXPRESSION HERE */; +{% endfor %} + } + } + + /** + * serialize this. + * @see mio::serialize + */ + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("Model"); + obj.add_element("Parameters", this->parameters); + obj.add_element("Populations", this->populations); + } + + /** + * deserialize an object of this class. + * @see mio::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.create_object("Model"); + auto par = obj.expect_element("Parameters", Tag{}); + auto pop = obj.expect_element("Populations", Tag{}); + return apply( + io, + [](auto&& par_, auto&& pop_) { + return Model{pop_, par_}; + }, + par, pop); + } +}; + +} // namespace {{ cfg.meta.namespace }} +} // namespace mio + +#endif // {{ cfg.meta.guard_prefix }}_MODEL_H diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 new file mode 100644 index 0000000000..e7b86c2ca4 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/parameters_h.jinja2 @@ -0,0 +1,242 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: generated by memilio-modelgenerator +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef {{ cfg.meta.guard_prefix }}_PARAMETERS_H +#define {{ cfg.meta.guard_prefix }}_PARAMETERS_H + +#include "memilio/config.h" +#include "memilio/epidemiology/age_group.h" +{% if cfg.has_infection_transition %} +#include "memilio/epidemiology/uncertain_matrix.h" +{% endif %} +#include "memilio/utils/custom_index_array.h" +#include "memilio/utils/parameter_set.h" +#include "memilio/utils/uncertain_value.h" + +namespace mio +{ +namespace {{ cfg.meta.namespace }} +{ + +/************************************************************ + * Define Parameters of the {{ cfg.meta.name }} model + ************************************************************/ + +{% for param in cfg.parameters %} +/** + * @brief {{ param.description }} + */ +template +struct {{ param.name }} { +{% if param.per_age_group %} + using Type = CustomIndexArray, AgeGroup>; + static Type get_default(AgeGroup size) + { + return Type(size, {{ "%.1f"|format(param.default) }}); + } +{% else %} + using Type = UncertainValue; + static Type get_default(AgeGroup /*size*/) + { + return Type({{ "%.1f"|format(param.default) }}); + } +{% endif %} + static std::string name() + { + return "{{ param.name }}"; + } +}; + +{% endfor %} +{% if cfg.has_infection_transition %} +/** + * @brief The contact patterns within the society are modelled using a ContactMatrix. + */ +template +struct ContactPatterns { + using Type = UncertainContactMatrix; + static Type get_default(AgeGroup size) + { + return Type(1, static_cast((size_t)size)); + } + static std::string name() + { + return "ContactPatterns"; + } +}; + +{% endif %} +template +using ParametersBase = + ParameterSet< +{% for param in cfg.parameters %} + {{ param.name }}{% if not loop.last or cfg.has_infection_transition %},{% endif %} + +{% endfor %} +{% if cfg.has_infection_transition %} + ContactPatterns +{% endif %} + >; + +/** + * @brief Parameters of the {{ cfg.meta.name }} model. + */ +template +class Parameters : public ParametersBase +{ +public: + Parameters(AgeGroup num_agegroups) + : ParametersBase(num_agegroups) + , m_num_groups{num_agegroups} + { + } + + AgeGroup get_num_groups() const + { + return m_num_groups; + } + + /** + * @brief Checks whether all Parameters satisfy their constraints and corrects them if not. + * + * @return Returns true if one or more constraints were corrected, false otherwise. + */ + bool apply_constraints() + { +{% for param in cfg.parameters %} +{% if param.type == 'time' %} + const FP lower_bound_{{ param.name }} = {{ "%.17g"|format(param.bounds[0]) }}; +{% if param.bounds[1] is not none %} + const FP upper_bound_{{ param.name }} = {{ "%.17g"|format(param.bounds[1]) }}; +{% endif %} +{% endif %} +{% endfor %} + int corrected = false; + + for (auto i = AgeGroup(0); i < AgeGroup(m_num_groups); ++i) { +{% for param in cfg.parameters %} +{% if param.type == 'time' %} + if (this->template get<{{ param.name }}>()[i] < lower_bound_{{ param.name }}) { + log_warning( + "Constraint check: Parameter {{ param.name }} changed from {} to {}. Please note that " + "unreasonably small compartment stays lead to massively increased run time. Consider to cancel " + "and reset parameters.", + this->template get<{{ param.name }}>()[i], lower_bound_{{ param.name }}); + this->template get<{{ param.name }}>()[i] = lower_bound_{{ param.name }}; + corrected = true; + } +{% if param.bounds[1] is not none %} + if (this->template get<{{ param.name }}>()[i] > upper_bound_{{ param.name }}) { + log_warning("Constraint check: Parameter {{ param.name }} changed from {} to {}", + this->template get<{{ param.name }}>()[i], upper_bound_{{ param.name }}); + this->template get<{{ param.name }}>()[i] = upper_bound_{{ param.name }}; + corrected = true; + } +{% endif %} +{% elif param.type == 'probability' %} + if (this->template get<{{ param.name }}>()[i] < 0.0 || + this->template get<{{ param.name }}>()[i] > 1.0) { + log_warning("Constraint check: Parameter {{ param.name }} changed from {} to {} ", + this->template get<{{ param.name }}>()[i], 0.0); + this->template get<{{ param.name }}>()[i] = 0.0; + corrected = true; + } +{% endif %} +{% endfor %} + } + return corrected; + } + + /** + * @brief Checks whether all Parameters satisfy their constraints and logs an error if not. + * @return Returns true if a constraint is violated, otherwise false. + */ + bool check_constraints() const + { +{% for param in cfg.parameters %} +{% if param.type == 'time' %} + const FP lower_bound_{{ param.name }} = {{ "%.17g"|format(param.bounds[0]) }}; +{% if param.bounds[1] is not none %} + const FP upper_bound_{{ param.name }} = {{ "%.17g"|format(param.bounds[1]) }}; +{% endif %} +{% endif %} +{% endfor %} + + for (auto i = AgeGroup(0); i < m_num_groups; i++) { +{% for param in cfg.parameters %} +{% if param.type == 'time' %} + if (this->template get<{{ param.name }}>()[i] < lower_bound_{{ param.name }}) { + log_warning( + "Constraint check: Parameter {{ param.name }} {} smaller than {}. Please note that " + "unreasonably small compartment stays lead to massively increased run time. Consider to cancel " + "and reset parameters.", + this->template get<{{ param.name }}>()[i], lower_bound_{{ param.name }}); + return true; + } +{% if param.bounds[1] is not none %} + if (this->template get<{{ param.name }}>()[i] > upper_bound_{{ param.name }}) { + log_error("Constraint check: Parameter {{ param.name }} {} greater {}", + this->template get<{{ param.name }}>()[i], upper_bound_{{ param.name }}); + return true; + } +{% endif %} +{% elif param.type == 'probability' %} + if (this->template get<{{ param.name }}>()[i] < 0.0 || + this->template get<{{ param.name }}>()[i] > 1.0) { + log_error("Constraint check: Parameter {{ param.name }} {} smaller {} or greater {}", + this->template get<{{ param.name }}>()[i], 0.0, 1.0); + return true; + } +{% endif %} +{% endfor %} + } + return false; + } + +private: + Parameters(ParametersBase&& base) + : ParametersBase(std::move(base)) +{% if cfg.has_infection_transition %} + , m_num_groups(this->template get>().get_cont_freq_mat().get_num_groups()) +{% else %} + , m_num_groups(AgeGroup(1)) +{% endif %} + { + } + +public: + /** + * deserialize an object of this class. + * @see mio::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + BOOST_OUTCOME_TRY(auto&& base, ParametersBase::deserialize(io)); + return success(Parameters(std::move(base))); + } + +private: + AgeGroup m_num_groups; +}; + +} // namespace {{ cfg.meta.namespace }} +} // namespace mio + +#endif // {{ cfg.meta.guard_prefix }}_PARAMETERS_H diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/pybindings_cpp.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/pybindings_cpp.jinja2 new file mode 100644 index 0000000000..e29df138b1 --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/pybindings_cpp.jinja2 @@ -0,0 +1,117 @@ +/* +* Copyright (C) 2020-2026 MEmilio +* +* Authors: generated by memilio-modelgenerator +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +//Includes from pymio +#include "pybind_util.h" +#include "utils/index.h" +#include "utils/custom_index_array.h" +#include "utils/parameter_set.h" +#include "compartments/simulation.h" +#include "compartments/flow_simulation.h" +#include "compartments/compartmental_model.h" +#include "epidemiology/age_group.h" +#include "epidemiology/populations.h" +#include "data/analyze_result.h" + +//Includes from MEmilio +#include "{{ cfg.meta.prefix }}/model.h" +#include "{{ cfg.meta.prefix }}/infection_state.h" +#include "memilio/data/analyze_result.h" + +#include "pybind11/pybind11.h" + +namespace py = pybind11; + +namespace pymio +{ +//specialization of pretty_name +template <> +inline std::string pretty_name() +{ + return "InfectionState"; +} + +} // namespace pymio + +PYBIND11_MODULE(_simulation_{{ cfg.meta.namespace }}, m) +{ + pymio::bind_interpolate_result_methods(m); + + pymio::iterable_enum(m, "InfectionState") +{% for state in cfg.infection_states %} + .value("{{ state }}", mio::{{ cfg.meta.namespace }}::InfectionState::{{ state }}){% if not loop.last %} +{% endif %} + +{% endfor %} + ; + + pymio::bind_ParameterSet< + mio::{{ cfg.meta.namespace }}::ParametersBase, + pymio::EnablePickling::Required>(m, "ParametersBase"); + + pymio::bind_class< + mio::{{ cfg.meta.namespace }}::Parameters, + pymio::EnablePickling::Required, + mio::{{ cfg.meta.namespace }}::ParametersBase>(m, "Parameters", py::module_local{}) + .def(py::init()) + .def("check_constraints", &mio::{{ cfg.meta.namespace }}::Parameters::check_constraints); + + using Populations = + mio::Populations; + + pymio::bind_Population(m, "Populations", + mio::Tag::Populations>{}); + + pymio::bind_CompartmentalModel< + mio::{{ cfg.meta.namespace }}::InfectionState, + Populations, + mio::{{ cfg.meta.namespace }}::Parameters, + pymio::EnablePickling::Never>(m, "ModelBase"); + + pymio::bind_class< + mio::{{ cfg.meta.namespace }}::Model, + pymio::EnablePickling::Required, + mio::CompartmentalModel< + double, + mio::{{ cfg.meta.namespace }}::InfectionState, + Populations, + mio::{{ cfg.meta.namespace }}::Parameters>>(m, "Model") + .def(py::init(), py::arg("num_agegroups")); + + pymio::bind_Simulation< + mio::Simulation>>(m, "Simulation"); + + pymio::bind_Flow_Simulation< + mio::FlowSimulation>>(m, "FlowSimulation"); + + m.def("simulate", + &mio::simulate>, + "Simulates an ODE {{ cfg.meta.name }} model from t0 to tmax.", + py::arg("t0"), py::arg("tmax"), py::arg("dt"), + py::arg("model"), py::arg("integrator") = py::none()); + + m.def("simulate_flows", + &mio::simulate_flows>, + "Simulates an ODE {{ cfg.meta.name }} model with flows from t0 to tmax.", + py::arg("t0"), py::arg("tmax"), py::arg("dt"), + py::arg("model"), py::arg("integrator") = py::none()); + + m.attr("__version__") = "dev"; +} diff --git a/pycode/memilio-generation/memilio/modelgenerator/templates/simulation_py.jinja2 b/pycode/memilio-generation/memilio/modelgenerator/templates/simulation_py.jinja2 new file mode 100644 index 0000000000..29577784bc --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/templates/simulation_py.jinja2 @@ -0,0 +1,25 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: generated by memilio-modelgenerator +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +Python bindings for MEmilio ODE {{ cfg.meta.name }} model. +""" + +from memilio.simulation._simulation_{{ cfg.meta.namespace }} import * diff --git a/pycode/memilio-generation/memilio/modelgenerator/validator.py b/pycode/memilio-generation/memilio/modelgenerator/validator.py new file mode 100644 index 0000000000..d8b2b3d37e --- /dev/null +++ b/pycode/memilio-generation/memilio/modelgenerator/validator.py @@ -0,0 +1,234 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +""" +Validation of a raw YAML dictionary before it is converted to `ModelConfig`. + +All errors are collected and raised together as a single +`ValidationError` so the user sees the full list at once. +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +from .schema import ParameterType, TransitionType + + +class ValidationError(Exception): + """Raised when one or more validation errors are found.""" + + def __init__(self, errors: list[str]): + self.errors = errors + bullet_list = "\n".join(f" • {e}" for e in errors) + super().__init__(f"Model configuration is invalid:\n{bullet_list}") + + +class Validator: + """Validate raw model dictionaries loaded from YAML or TOML.""" + + @staticmethod + def validate(data: dict[str, Any]) -> None: + """ + Validate ``data``. + + :param data: Dictionary as returned by ``yaml.safe_load``. + :type data: dict[str, Any] + :raises ValidationError: If one or more validation errors are found. + """ + errors: list[str] = [] + + # model + model = data.get("model") + if not isinstance(model, dict): + errors.append("'model' section is missing or not a mapping.") + else: + for key in ("name", "namespace", "prefix"): + if not isinstance( + model.get(key), + str) or not model[key].strip(): + errors.append(f"'model.{key}' must be a non-empty string.") + + # infection_states + states = data.get("infection_states") + if not isinstance(states, list) or len(states) < 2: + errors.append( + "'infection_states' must be a list with at least 2 entries.") + states = [] + else: + for i, s in enumerate(states): + if not isinstance(s, str) or not s.strip(): + errors.append( + f"'infection_states[{i}]' must be a non-empty string.") + if len(states) != len(set(states)): + errors.append("'infection_states' contains duplicate entries.") + + state_set = set(states) + + # parameters + params = data.get("parameters") + if not isinstance(params, list) or len(params) == 0: + errors.append("'parameters' must be a non-empty list.") + params = [] + + param_names: list[str] = [] + for i, p in enumerate(params): + loc = f"parameters[{i}]" + if not isinstance(p, dict): + errors.append(f"'{loc}' must be a mapping.") + continue + + name = p.get("name") + if not isinstance(name, str) or not name.strip(): + errors.append(f"'{loc}.name' must be a non-empty string.") + else: + param_names.append(name) + + if not isinstance(p.get("description"), str): + errors.append(f"'{loc}.description' must be a string.") + + ptype = p.get("type") + if ptype not in ParameterType.ALL: + errors.append( + f"'{loc}.type' must be one of {ParameterType.ALL}, got {ptype!r}." + ) + + default = p.get("default") + if not isinstance(default, (int, float)): + errors.append(f"'{loc}.default' must be a number.") + + bounds = p.get("bounds") + if bounds is not None: + if not ( + isinstance(bounds, (list, tuple)) + and len(bounds) == 2 + and all(b is None or isinstance(b, (int, float)) for b in bounds) + ): + errors.append( + f"'{loc}.bounds' must be a list of two numbers or null, e.g. [0.0, 1.0]." + ) + else: + lower, upper = bounds + if lower is not None and upper is not None and lower > upper: + errors.append( + f"'{loc}.bounds' lower value must be <= upper value (got {lower} > {upper})." + ) + if ptype == ParameterType.TIME and upper is not None and upper < 0.1: + errors.append( + f"'{loc}.bounds' upper value for type 'time' must be >= 0.1." + ) + + if len(param_names) != len(set(param_names)): + errors.append("'parameters' contains duplicate 'name' entries.") + + param_name_set = set(param_names) + + # transitions + transitions = data.get("transitions") + if not isinstance(transitions, list) or len(transitions) == 0: + errors.append("'transitions' must be a non-empty list.") + transitions = [] + + for i, t in enumerate(transitions): + loc = f"transitions[{i}]" + if not isinstance(t, dict): + errors.append(f"'{loc}' must be a mapping.") + continue + + from_state = t.get("from") + to_state = t.get("to") + ttype = t.get("type") + + if from_state not in state_set: + errors.append( + f"'{loc}.from' references unknown state {from_state!r}." + ) + if to_state not in state_set: + errors.append( + f"'{loc}.to' references unknown state {to_state!r}." + ) + if from_state == to_state and from_state is not None: + errors.append( + f"'{loc}': 'from' and 'to' must differ (got {from_state!r})." + ) + + if ttype not in TransitionType.ALL: + errors.append( + f"'{loc}.type' must be one of {TransitionType.ALL}, got {ttype!r}." + ) + continue + + if ttype in (TransitionType.INFECTION, TransitionType.LINEAR): + param = t.get("parameter") + if param not in param_name_set: + errors.append( + f"'{loc}.parameter' references unknown parameter {param!r}." + ) + + if ttype == TransitionType.INFECTION: + has_singular = "infectious_state" in t + has_plural = "infectious_states" in t + if has_singular and has_plural: + errors.append( + f"'{loc}' must define only one of 'infectious_state' or 'infectious_states'." + ) + continue + + key = "infectious_states" if has_plural else "infectious_state" + inf_raw = t.get(key) + + if inf_raw is None: + errors.append( + f"'{loc}.{key}' must be provided for infection transitions." + ) + continue + + if isinstance(inf_raw, str): + inf_states = [inf_raw] + elif isinstance(inf_raw, list): + if len(inf_raw) == 0: + errors.append(f"'{loc}.{key}' must not be empty.") + continue + inf_states = [] + for j, s in enumerate(inf_raw): + if not isinstance(s, str) or not s.strip(): + errors.append( + f"'{loc}.{key}[{j}]' must be a non-empty string." + ) + else: + inf_states.append(s) + if len(inf_states) != len(set(inf_states)): + errors.append( + f"'{loc}.{key}' contains duplicate entries." + ) + else: + errors.append( + f"'{loc}.{key}' must be a string or a non-empty list of strings." + ) + continue + + for s in inf_states: + if s not in state_set: + errors.append( + f"'{loc}.{key}' references unknown state {s!r}." + ) + + if errors: + raise ValidationError(errors) diff --git a/pycode/memilio-generation/pyproject.toml b/pycode/memilio-generation/pyproject.toml index ea206d2557..a1402400cd 100644 --- a/pycode/memilio-generation/pyproject.toml +++ b/pycode/memilio-generation/pyproject.toml @@ -22,7 +22,10 @@ dependencies = [ "dataclasses", "dataclasses_json", "graphviz", - "importlib-resources>=1.1.0; python_version < '3.9'" + "importlib-resources>=1.1.0; python_version < '3.9'", + "jinja2>=3.0.0", + "pyyaml>=6.0", + "tomli>=1.1.0; python_version < '3.11'" ] [project.optional-dependencies] @@ -42,3 +45,4 @@ include-package-data = true [tool.setuptools.package-data] "memilio.generation" = ["tools/*.json", "tools/*.txt", "tools/README.md"] +"memilio.modelgenerator" = ["templates/*.jinja2"] diff --git a/pycode/memilio-generation/tests/test_modelgenerator.py b/pycode/memilio-generation/tests/test_modelgenerator.py new file mode 100644 index 0000000000..b85b41c57a --- /dev/null +++ b/pycode/memilio-generation/tests/test_modelgenerator.py @@ -0,0 +1,854 @@ +############################################################################# +# Copyright (C) 2020-2026 MEmilio +# +# Authors: Henrik Zunker +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import os +import tempfile +import unittest + +from memilio.modelgenerator import Generator +from memilio.modelgenerator.validator import ValidationError + +HERE = os.path.dirname(os.path.abspath(__file__)) +EXAMPLES_DIR = os.path.join(HERE, "..", "..", "examples", "modelgenerator") + +SEIR_YAML = os.path.join(EXAMPLES_DIR, "seir.yaml") +SEIRD_YAML = os.path.join(EXAMPLES_DIR, "seird.yaml") +SEIR_TOML = os.path.join(EXAMPLES_DIR, "seir.toml") + + +def _render(yaml_path: str) -> dict: + return Generator.from_yaml(yaml_path).render() + + +# Parsing +class TestParsing(unittest.TestCase): + + def test_seir_meta(self): + gen = Generator.from_yaml(SEIR_YAML) + cfg = gen._config + self.assertEqual(cfg.meta.name, "SEIR") + self.assertEqual(cfg.meta.namespace, "oseir") + self.assertEqual(cfg.meta.prefix, "ode_seir") + + def test_seir_states(self): + gen = Generator.from_yaml(SEIR_YAML) + self.assertEqual(gen._config.infection_states, + ["Susceptible", "Exposed", "Infected", "Recovered"]) + + def test_seir_parameters(self): + gen = Generator.from_yaml(SEIR_YAML) + names = [p.name for p in gen._config.parameters] + self.assertIn("TransmissionProbabilityOnContact", names) + self.assertIn("TimeExposed", names) + self.assertIn("TimeInfected", names) + + def test_seir_transitions(self): + gen = Generator.from_yaml(SEIR_YAML) + types = [t.type for t in gen._config.transitions] + self.assertIn("infection", types) + self.assertIn("linear", types) + + def test_seird_has_custom_transition(self): + gen = Generator.from_yaml(SEIRD_YAML) + custom = [t for t in gen._config.transitions if t.type == "custom"] + self.assertEqual(len(custom), 1) + self.assertEqual(custom[0].from_state, "Infected") + self.assertEqual(custom[0].to_state, "Dead") + + def test_has_infection_transition_flag(self): + gen = Generator.from_yaml(SEIR_YAML) + self.assertTrue(gen._config.has_infection_transition) + + def test_parameter_defaults(self): + gen = Generator.from_yaml(SEIR_YAML) + by_name = {p.name: p for p in gen._config.parameters} + self.assertAlmostEqual(by_name["TimeExposed"].default, 5.2) + self.assertAlmostEqual(by_name["TimeInfected"].default, 6.0) + self.assertAlmostEqual( + by_name["TransmissionProbabilityOnContact"].default, 1.0) + + def test_parameter_bounds(self): + gen = Generator.from_yaml(SEIR_YAML) + by_name = {p.name: p for p in gen._config.parameters} + prob = by_name["TransmissionProbabilityOnContact"] + self.assertEqual(prob.bounds, (0.0, 1.0)) + time_exp = by_name["TimeExposed"] + self.assertAlmostEqual(time_exp.bounds[0], 0.1) + self.assertIsNone(time_exp.bounds[1]) + + def test_time_bound_lower_floor_applied(self): + d = { + "model": {"name": "X", "namespace": "ox", "prefix": "ode_x"}, + "infection_states": ["S", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "time", "default": 1.0, "bounds": [0.01, None]} + ], + "transitions": [ + {"from": "S", "to": "I", "type": "linear", "parameter": "Rate"} + ], + } + gen = Generator.from_dict(d) + self.assertAlmostEqual(gen._config.parameters[0].bounds[0], 0.1) + + +# TOML loading +class TestTomlLoading(unittest.TestCase): + + def test_toml_parses_same_meta_as_yaml(self): + gen_yaml = Generator.from_yaml(SEIR_YAML) + gen_toml = Generator.from_toml(SEIR_TOML) + self.assertEqual(gen_toml._config.meta.name, + gen_yaml._config.meta.name) + self.assertEqual(gen_toml._config.meta.namespace, + gen_yaml._config.meta.namespace) + self.assertEqual(gen_toml._config.meta.prefix, + gen_yaml._config.meta.prefix) + + def test_toml_parses_same_states(self): + gen_yaml = Generator.from_yaml(SEIR_YAML) + gen_toml = Generator.from_toml(SEIR_TOML) + self.assertEqual(gen_toml._config.infection_states, + gen_yaml._config.infection_states) + + def test_toml_parses_same_parameters(self): + gen_yaml = Generator.from_yaml(SEIR_YAML) + gen_toml = Generator.from_toml(SEIR_TOML) + names_yaml = [p.name for p in gen_yaml._config.parameters] + names_toml = [p.name for p in gen_toml._config.parameters] + self.assertEqual(names_toml, names_yaml) + + def test_toml_parses_same_transitions(self): + gen_yaml = Generator.from_yaml(SEIR_YAML) + gen_toml = Generator.from_toml(SEIR_TOML) + types_yaml = [t.type for t in gen_yaml._config.transitions] + types_toml = [t.type for t in gen_toml._config.transitions] + self.assertEqual(types_toml, types_yaml) + + def test_toml_renders_identical_model_h(self): + files_yaml = Generator.from_yaml(SEIR_YAML).render() + files_toml = Generator.from_toml(SEIR_TOML).render() + self.assertEqual( + files_toml["cpp/models/ode_seir/model.h"], + files_yaml["cpp/models/ode_seir/model.h"]) + + def test_toml_renders_identical_infection_state_h(self): + files_yaml = Generator.from_yaml(SEIR_YAML).render() + files_toml = Generator.from_toml(SEIR_TOML).render() + self.assertEqual( + files_toml["cpp/models/ode_seir/infection_state.h"], + files_yaml["cpp/models/ode_seir/infection_state.h"]) + + +# infection_state.h template +class TestInfectionStateTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + self.content = self.files["cpp/models/ode_seir/infection_state.h"] + + def test_include_guard(self): + self.assertIn("#ifndef ODE_SEIR_INFECTIONSTATE_H", self.content) + self.assertIn("#define ODE_SEIR_INFECTIONSTATE_H", self.content) + self.assertIn("#endif // ODE_SEIR_INFECTIONSTATE_H", self.content) + + def test_namespace(self): + self.assertIn("namespace oseir", self.content) + + def test_all_states_present(self): + for state in [ + "Susceptible", "Exposed", "Infected", "Recovered", "Count"]: + self.assertIn(state, self.content) + + def test_enum_class(self): + self.assertIn("enum class InfectionState", self.content) + + +# parameters.h template +class TestParametersTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + self.content = self.files["cpp/models/ode_seir/parameters.h"] + + def test_include_guard(self): + self.assertIn("#ifndef ODE_SEIR_PARAMETERS_H", self.content) + + def test_parameter_structs(self): + for name in [ + "TransmissionProbabilityOnContact", "TimeExposed", "TimeInfected"]: + self.assertIn(f"struct {name}", self.content) + + def test_contact_patterns_added(self): + self.assertIn("struct ContactPatterns", self.content) + + def test_parameters_base(self): + self.assertIn("using ParametersBase =", self.content) + self.assertIn("ContactPatterns", self.content) + + def test_parameters_class(self): + self.assertIn( + "class Parameters : public ParametersBase", self.content) + self.assertIn("apply_constraints", self.content) + self.assertIn("check_constraints", self.content) + + def test_probability_constraint(self): + self.assertIn( + "TransmissionProbabilityOnContact>()[i] < 0.0", self.content) + + def test_time_constraint(self): + self.assertIn("lower_bound_TimeExposed", self.content) + self.assertIn("lower_bound_TimeInfected", self.content) + + def test_time_upper_bound_constraint(self): + d = { + "model": {"name": "SI", "namespace": "osi", "prefix": "ode_si"}, + "infection_states": ["S", "I"], + "parameters": [ + {"name": "Rate", "description": "d", "type": "time", + "default": 5.0, "bounds": [0.5, 7.5]} + ], + "transitions": [ + {"from": "S", "to": "I", "type": "linear", "parameter": "Rate"} + ], + } + content = Generator.from_dict( + d).render()["cpp/models/ode_si/parameters.h"] + self.assertIn("upper_bound_Rate", content) + self.assertIn("greater", content) + + def test_default_values_in_get_default(self): + # TimeExposed default = 5.2, TimeInfected = 6.0 + self.assertIn("5.2", self.content) + self.assertIn("6.0", self.content) + + def test_no_contact_patterns_without_infection(self): + # A model with only linear transitions must not get ContactPatterns + d = { + "model": {"name": "SI", "namespace": "osi", "prefix": "ode_si"}, + "infection_states": ["S", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "time", "default": 5.0} + ], + "transitions": [ + {"from": "S", "to": "I", "type": "linear", "parameter": "Rate"} + ], + } + content = Generator.from_dict(d).render()[ + "cpp/models/ode_si/parameters.h"] + self.assertNotIn("ContactPatterns", content) + + +# model.h template +class TestModelTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + self.content = self.files["cpp/models/ode_seir/model.h"] + + def test_include_guard(self): + self.assertIn("#ifndef ODE_SEIR_MODEL_H", self.content) + + def test_flows_typelist(self): + self.assertIn("using Flows = TypeList<", self.content) + self.assertIn( + "Flow", + self.content) + self.assertIn( + "Flow", + self.content) + + def test_get_flows_method(self): + self.assertIn("void get_flows(", self.content) + + def test_infection_flow_uses_contact_matrix(self): + self.assertIn("ContactPatterns", self.content) + self.assertIn("get_cont_freq_mat", self.content) + + def test_infection_flow_supports_multiple_infectious_states(self): + d = { + "model": {"name": "SEIIR", "namespace": "oseiir", "prefix": "ode_seiir"}, + "infection_states": ["S", "E", "I1", "I2", "R"], + "parameters": [ + {"name": "Beta", "description": "b", + "type": "probability", "default": 0.5}, + {"name": "TimeExposed", "description": "t", + "type": "time", "default": 4.0}, + ], + "transitions": [ + {"from": "S", "to": "E", "type": "infection", "parameter": "Beta", + "infectious_state": ["I1", "I2"]}, + {"from": "E", "to": "R", "type": "linear", + "parameter": "TimeExposed"}, + ], + } + content = Generator.from_dict( + d).render()["cpp/models/ode_seiir/model.h"] + self.assertIn("pop[idx_I1_j] +", content) + self.assertIn("pop[idx_I2_j]", content) + + def test_infection_flow_supports_infectious_states_key(self): + d = { + "model": {"name": "SEIIR", "namespace": "oseiir", "prefix": "ode_seiir"}, + "infection_states": ["S", "E", "I1", "I2", "R"], + "parameters": [ + {"name": "Beta", "description": "b", + "type": "probability", "default": 0.5}, + {"name": "TimeExposed", "description": "t", + "type": "time", "default": 4.0}, + ], + "transitions": [ + {"from": "S", "to": "E", "type": "infection", "parameter": "Beta", + "infectious_states": ["I1", "I2"]}, + {"from": "E", "to": "R", "type": "linear", + "parameter": "TimeExposed"}, + ], + } + content = Generator.from_dict( + d).render()["cpp/models/ode_seiir/model.h"] + self.assertIn("pop[idx_I1_j] +", content) + self.assertIn("pop[idx_I2_j]", content) + + def test_linear_flows(self): + self.assertIn("TimeExposed>()[i]", self.content) + self.assertIn("TimeInfected>()[i]", self.content) + + def test_serialize_deserialize(self): + self.assertIn("void serialize(", self.content) + self.assertIn("static IOResult deserialize(", self.content) + + def test_index_variables_for_all_states(self): + for state in ["Susceptible", "Exposed", "Infected", "Recovered"]: + self.assertIn(f"idx_{state}_i", self.content) + + def test_seird_custom_transition_todo(self): + content = Generator.from_yaml(SEIRD_YAML).render()[ + "cpp/models/ode_seird/model.h"] + self.assertIn("TODO", content) + self.assertIn("YOUR EXPRESSION HERE", content) + + def test_seird_custom_formula_hint(self): + content = Generator.from_yaml(SEIRD_YAML).render()[ + "cpp/models/ode_seird/model.h"] + self.assertIn("DeathRate[i] * y[idx_Infected_i]", content) + + +# pybindings.cpp template +class TestPybindingsTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + key = "pycode/memilio-simulation/memilio/simulation/bindings/models/ode_seir.cpp" + self.content = self.files[key] + + def test_module_name(self): + self.assertIn("PYBIND11_MODULE(_simulation_oseir, m)", self.content) + + def test_enum_values(self): + for state in ["Susceptible", "Exposed", "Infected", "Recovered"]: + self.assertIn(f'.value("{state}"', self.content) + + def test_simulate_functions(self): + self.assertIn('m.def("simulate"', self.content) + self.assertIn('m.def("simulate_flows"', self.content) + + def test_model_init(self): + self.assertIn("py::init()", self.content) + + +# CMakeLists.txt template +class TestCMakeTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + self.content = self.files["cpp/models/ode_seir/CMakeLists.txt"] + + def test_add_library(self): + self.assertIn("add_library(ode_seir", self.content) + + def test_source_files(self): + for src in ["infection_state.h", "parameters.h", "model.h", + "model.cpp"]: + self.assertIn(src, self.content) + + def test_link_libraries(self): + self.assertIn( + "target_link_libraries(ode_seir PUBLIC memilio)", self.content) + + +# Python example template +class TestExampleTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + self.key = "pycode/examples/simulation/ode_seir_simple.py" + self.content = self.files[self.key] + + def test_example_key_in_render(self): + self.assertIn(self.key, self.files) + + def test_imports_numpy(self): + self.assertIn("import numpy as np", self.content) + + def test_imports_agegroup(self): + self.assertIn("from memilio.simulation import AgeGroup", self.content) + + def test_imports_correct_module(self): + self.assertIn( + "from memilio.simulation.oseir import", self.content) + + def test_imports_damping_for_infection_model(self): + self.assertIn("from memilio.simulation import Damping", self.content) + + def test_simulate_call(self): + self.assertIn("simulate(t0, tmax, dt, model)", self.content) + + def test_interpolate_call(self): + self.assertIn("interpolate_simulation_result(result)", self.content) + + def test_default_parameter_values(self): + # TransmissionProbabilityOnContact default = 1.0 + self.assertIn( + "model.parameters.TransmissionProbabilityOnContact[A0] = 1.0", + self.content) + # TimeExposed default = 5.2 + self.assertIn( + "model.parameters.TimeExposed[A0] = 5.2", self.content) + # TimeInfected default = 6.0 + self.assertIn( + "model.parameters.TimeInfected[A0] = 6.0", self.content) + + def test_contact_patterns_setup(self): + self.assertIn("cont_freq_mat[0].baseline", self.content) + self.assertIn("cont_freq_mat[0].minimum", self.content) + + def test_initial_conditions(self): + self.assertIn("State.Exposed", self.content) + self.assertIn("set_difference_from_total", self.content) + self.assertIn("State.Susceptible", self.content) + + def test_print_table(self): + self.assertIn("get_num_time_points", self.content) + self.assertIn("get_time", self.content) + self.assertIn("get_value", self.content) + + def test_run_simulation_function(self): + self.assertIn("def run_simulation(", self.content) + + def test_main_guard(self): + self.assertIn('if __name__ == "__main__":', self.content) + + def test_tmax_10_days_default(self): + self.assertIn("tmax=10.0", self.content) + + def test_no_damping_for_linear_only_model(self): + d = { + "model": {"name": "SI", "namespace": "osi", "prefix": "ode_si"}, + "infection_states": ["S", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "time", "default": 5.0} + ], + "transitions": [ + {"from": "S", "to": "I", "type": "linear", "parameter": "Rate"} + ], + } + content = Generator.from_dict(d).render()[ + "pycode/examples/simulation/ode_si_simple.py"] + self.assertNotIn("Damping", content) + self.assertNotIn("ContactPatterns", content) + + def test_seird_example_uses_second_state(self): + content = Generator.from_yaml(SEIRD_YAML).render()[ + "pycode/examples/simulation/ode_seird_simple.py"] + # Second state is Exposed, initial conditions should seed it + self.assertIn("State.Exposed", content) + self.assertIn("State.Susceptible", content) + + +# simulation_py +class TestSimulationPyTemplate(unittest.TestCase): + + def setUp(self): + self.files = _render(SEIR_YAML) + self.key = "pycode/memilio-simulation/memilio/simulation/oseir.py" + self.content = self.files[self.key] + + def test_key_in_render(self): + self.assertIn(self.key, self.files) + + def test_imports_compiled_module(self): + self.assertIn( + "from memilio.simulation._simulation_oseir import *", self.content) + + def test_namespace_in_key(self): + # namespace drives the filename + files = Generator.from_yaml(SEIRD_YAML).render() + self.assertIn( + "pycode/memilio-simulation/memilio/simulation/oseird.py", files) + + def test_seird_imports_correct_module(self): + files = Generator.from_yaml(SEIRD_YAML).render() + content = files["pycode/memilio-simulation/memilio/simulation/oseird.py"] + self.assertIn( + "from memilio.simulation._simulation_oseird import *", content) + + +# Validation +class TestValidation(unittest.TestCase): + + def _base(self): + return { + "model": {"name": "X", "namespace": "ox", "prefix": "ode_x"}, + "infection_states": ["S", "I"], + "parameters": [ + {"name": "Rate", "description": "d", + "type": "time", "default": 1.0} + ], + "transitions": [ + {"from": "S", "to": "I", "type": "linear", "parameter": "Rate"} + ], + } + + def test_missing_model_section(self): + d = self._base() + del d["model"] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_unknown_state_in_transition(self): + d = self._base() + d["transitions"][0]["from"] = "X_unknown" + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_unknown_parameter_in_transition(self): + d = self._base() + d["transitions"][0]["parameter"] = "NoSuchParam" + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_invalid_transition_type(self): + d = self._base() + d["transitions"][0]["type"] = "magic" + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_duplicate_states(self): + d = self._base() + d["infection_states"] = ["S", "S"] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_self_loop_transition(self): + d = self._base() + d["transitions"][0]["to"] = "S" + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_too_few_states(self): + d = self._base() + d["infection_states"] = ["S"] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_missing_infectious_state_for_infection_transition(self): + d = self._base() + d["transitions"] = [ + {"from": "S", "to": "I", "type": "infection", + "parameter": "Rate", "infectious_state": "Unknown"} + ] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_validation_error_lists_all_errors(self): + d = self._base() + del d["model"] + d["infection_states"] = ["S"] + try: + Generator.from_dict(d) + self.fail("Expected ValidationError") + except ValidationError as exc: + self.assertGreater(len(exc.errors), 1) + + def test_description_must_be_string(self): + d = self._base() + d["parameters"][0]["description"] = 42 + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_duplicate_parameter_names(self): + d = self._base() + d["parameters"].append( + {"name": "Rate", "description": "dup", "type": "time", "default": 2.0} + ) + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_infectious_state_list_accepts_multiple_states(self): + d = self._base() + d["infection_states"] = ["S", "I1", "I2"] + d["transitions"] = [ + {"from": "S", "to": "I1", "type": "infection", + "parameter": "Rate", "infectious_state": ["I1", "I2"]} + ] + Generator.from_dict(d) + + def test_infectious_states_key_accepts_multiple_states(self): + d = self._base() + d["infection_states"] = ["S", "I1", "I2"] + d["transitions"] = [ + {"from": "S", "to": "I1", "type": "infection", + "parameter": "Rate", "infectious_states": ["I1", "I2"]} + ] + Generator.from_dict(d) + + def test_empty_infectious_state_list_rejected(self): + d = self._base() + d["transitions"] = [ + {"from": "S", "to": "I", "type": "infection", + "parameter": "Rate", "infectious_state": []} + ] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_conflicting_infectious_state_keys_rejected(self): + d = self._base() + d["transitions"] = [ + {"from": "S", "to": "I", "type": "infection", "parameter": "Rate", + "infectious_state": "I", "infectious_states": ["I"]} + ] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_invalid_bounds_order_rejected(self): + d = self._base() + d["parameters"][0]["bounds"] = [2.0, 1.0] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + def test_time_upper_bound_below_floor_rejected(self): + d = self._base() + d["parameters"][0]["bounds"] = [None, 0.05] + with self.assertRaises(ValidationError): + Generator.from_dict(d) + + +# CMakeLists patching +_CPP_CMAKE_STUB = """\ +if(MEMILIO_BUILD_MODELS) + add_subdirectory(models/ode_sir) + add_subdirectory(models/ode_seir) + add_subdirectory(models/ode_mseirs4) +endif() +""" + +_SIM_CMAKE_STUB = """\ +add_pymio_module(_simulation_oseir + LINKED_LIBRARIES memilio ode_seir + SOURCES memilio/simulation/bindings/models/oseir.cpp +) + +# install all shared memilio libraries, which were given as "LINKED_LIBRARIES" to add_pymio_module +list(REMOVE_DUPLICATES PYMIO_MEMILIO_LIBS_LIST) +""" + +_SIM_INIT_STUB = """\ +from memilio.simulation._simulation import * + + +def __getattr__(attr): + if attr == "oseir": + import memilio.simulation.oseir as oseir + return oseir + raise AttributeError("module {!r} has no attribute {!r}".format(__name__, attr)) +""" + + +class TestCMakePatching(unittest.TestCase): + + def _make_repo(self, cpp_cmake=_CPP_CMAKE_STUB, sim_cmake=_SIM_CMAKE_STUB, + sim_init=_SIM_INIT_STUB): + from pathlib import Path + tmp = tempfile.mkdtemp() + cpp_dir = os.path.join(tmp, "cpp") + sim_dir = os.path.join(tmp, "pycode", "memilio-simulation") + sim_pkg_dir = os.path.join(sim_dir, "memilio", "simulation") + os.makedirs(cpp_dir) + os.makedirs(sim_pkg_dir) + with open(os.path.join(cpp_dir, "CMakeLists.txt"), "w") as f: + f.write(cpp_cmake) + with open(os.path.join(sim_dir, "CMakeLists.txt"), "w") as f: + f.write(sim_cmake) + with open(os.path.join(sim_pkg_dir, "__init__.py"), "w") as f: + f.write(sim_init) + return tmp + + def _gen(self): + return Generator.from_yaml(SEIR_YAML) + + def test_cpp_cmake_gets_patched(self): + from pathlib import Path + d = { + "model": {"name": "SIR", "namespace": "osir_new", "prefix": "ode_sir_new"}, + "infection_states": ["Susceptible", "Infected", "Recovered"], + "parameters": [ + {"name": "TransmissionRate", "description": "rate", + "type": "probability", "default": 0.3}, + {"name": "RecoveryTime", "description": "time", + "type": "time", "default": 7.0}, + ], + "transitions": [ + {"from": "Susceptible", "to": "Infected", "type": "infection", + "parameter": "TransmissionRate", "infectious_state": "Infected"}, + {"from": "Infected", "to": "Recovered", + "type": "linear", "parameter": "RecoveryTime"}, + ], + } + gen = Generator.from_dict(d) + tmp = self._make_repo() + patches = gen.render_patches(Path(tmp)) + patched = patches[gen._CPP_CMAKE] + self.assertIsNotNone(patched) + self.assertIn("add_subdirectory(models/ode_sir_new)", patched) + self.assertIn("add_subdirectory(models/ode_seir)", patched) + + def test_cpp_cmake_no_duplicate(self): + from pathlib import Path + gen = self._gen() + tmp = self._make_repo() + patches = gen.render_patches(Path(tmp)) + self.assertIsNone(patches[gen._CPP_CMAKE]) + + def test_sim_cmake_gets_patched(self): + from pathlib import Path + d = { + "model": {"name": "SIR", "namespace": "osir_new", "prefix": "ode_sir_new"}, + "infection_states": ["Susceptible", "Infected", "Recovered"], + "parameters": [ + {"name": "TransmissionRate", "description": "rate", + "type": "probability", "default": 0.3}, + {"name": "RecoveryTime", "description": "time", + "type": "time", "default": 7.0}, + ], + "transitions": [ + {"from": "Susceptible", "to": "Infected", "type": "infection", + "parameter": "TransmissionRate", "infectious_state": "Infected"}, + {"from": "Infected", "to": "Recovered", + "type": "linear", "parameter": "RecoveryTime"}, + ], + } + gen = Generator.from_dict(d) + tmp = self._make_repo() + patches = gen.render_patches(Path(tmp)) + patched = patches[gen._SIM_CMAKE] + self.assertIsNotNone(patched) + self.assertIn("add_pymio_module(_simulation_osir_new", patched) + self.assertIn("LINKED_LIBRARIES memilio ode_sir_new", patched) + self.assertIn( + "SOURCES memilio/simulation/bindings/models/ode_sir_new.cpp", + patched) + self.assertIn("_simulation_oseir", patched) + self.assertIn("list(REMOVE_DUPLICATES", patched) + + def test_sim_cmake_no_duplicate(self): + from pathlib import Path + gen = self._gen() + tmp = self._make_repo() + patches = gen.render_patches(Path(tmp)) + self.assertIsNone(patches[gen._SIM_CMAKE]) + + def test_write_creates_all_files(self): + from pathlib import Path + gen = self._gen() + tmp = self._make_repo() + gen.write(tmp) + prefix = gen._config.meta.prefix + expected = [ + f"cpp/models/{prefix}/infection_state.h", + f"cpp/models/{prefix}/parameters.h", + f"cpp/models/{prefix}/model.h", + f"cpp/models/{prefix}/model.cpp", + f"cpp/models/{prefix}/CMakeLists.txt", + ] + for rel in expected: + self.assertTrue( + (Path(tmp) / rel).exists(), f"Missing: {rel}") + + def test_write_raises_if_model_dir_exists(self): + from pathlib import Path + gen = self._gen() + tmp = self._make_repo() + # First write succeeds + gen.write(tmp) + # Second write without overwrite=True must fail + with self.assertRaises(FileExistsError): + gen.write(tmp) + + def test_write_overwrite_flag_allows_second_write(self): + from pathlib import Path + gen = self._gen() + tmp = self._make_repo() + gen.write(tmp) + # Should not raise + gen.write(tmp, overwrite=True) + + def test_sim_init_gets_patched(self): + from pathlib import Path + d = { + "model": {"name": "SIR", "namespace": "osir_new", "prefix": "ode_sir_new"}, + "infection_states": ["Susceptible", "Infected", "Recovered"], + "parameters": [ + {"name": "TransmissionRate", "description": "rate", + "type": "probability", "default": 0.3}, + {"name": "RecoveryTime", "description": "time", + "type": "time", "default": 7.0}, + ], + "transitions": [ + {"from": "Susceptible", "to": "Infected", "type": "infection", + "parameter": "TransmissionRate", "infectious_state": "Infected"}, + {"from": "Infected", "to": "Recovered", + "type": "linear", "parameter": "RecoveryTime"}, + ], + } + gen = Generator.from_dict(d) + tmp = self._make_repo() + patches = gen.render_patches(Path(tmp)) + patched = patches[gen._SIM_INIT] + self.assertIsNotNone(patched) + self.assertIn('attr == "osir_new"', patched) + self.assertIn( + "import memilio.simulation.osir_new as osir_new", patched) + # original entry must still be there + self.assertIn('attr == "oseir"', patched) + # raise AttributeError must still be there + self.assertIn("raise AttributeError", patched) + + def test_sim_init_no_duplicate(self): + from pathlib import Path + gen = self._gen() # namespace = oseir, already in stub + tmp = self._make_repo() + patches = gen.render_patches(Path(tmp)) + self.assertIsNone(patches[gen._SIM_INIT]) + + +if __name__ == "__main__": + unittest.main()