# Implementing an ODE model in Pymob

In this tutorial, we will implement a simple ODE model, create simulation results and infer an unknown parameter from artificially generated data. It is recommended to work through this notebook after the introductory tutorial, where something very similar is done for a linear regression model. After setting up the simulation manually (Chapter 1), we will save our settings and create a new simulation from those settings (Chapter 2).

# Chapter 1: Setting up the model 👩‍💻

👉 Let's begin by setting up a Pymob simulation for an ODE model. This follows roughly the same procedure as the introductory tutorial. We do, however, need to make some tweaks to accommodate the needs of an ODE model.

```python
# First, import the necessary Python packages
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from scipy.integrate import solve_ivp

# Import the pymob modules
from pymob.simulation import SimulationBase
from pymob.solvers.diffrax import JaxSolver
from pymob.sim.config import Param, DataVariable
```

## 1.1 Creating the `sim` object 🧩

👉 As an example of a relatively simple ODE model, we will use the well-known **Lotka-Volterra model**, which describes a predator-prey relationship.

👉 The equations for this model look like this ($X$ and $Y$ denote prey and predator, respectively):

$$
\begin{aligned}
\frac{dX}{dt} &= \alpha X - \beta X Y \\
\frac{dY}{dt} &= \gamma X Y - \delta Y
\end{aligned}
\qquad \alpha, \beta, \gamma, \delta > 0
$$

👉 In the following cell, we will define our model. To work with our solver (we will later use {class}`pymob.solvers.diffrax.JaxSolver`, which calls `diffrax.diffeqsolve`), our Python function needs a signature of the form `fun(t, y, *args)`. Here `t` represents the current time within the system, `y` represents the current system state and `*args` is a placeholder for all model parameters.
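The following minimal sketch makes the `fun(t, y, *args)` convention concrete. It uses a hypothetical logistic-growth model that is not part of this tutorial, together with SciPy's `solve_ivp`, whose `args` tuple is unpacked into the model call in the same way. It only illustrates the calling convention and is not meant to show how `JaxSolver` is implemented internally.

```python
import numpy as np
from scipy.integrate import solve_ivp


def logistic_growth(t, y, r, K):
    """Hypothetical example model: dy/dt = r * y * (1 - y / K).

    `t` is not used (the model is autonomous) but has to stay in the
    signature because the solver always passes it first.
    """
    return r * y * (1 - y / K)


# All model parameters, in the order they follow `t` and `y`
params = (0.5, 100.0)  # growth rate r, carrying capacity K

# Evaluate the right-hand side the way a solver does: fun(t, y, *args)
rate_at_start = logistic_growth(0.0, 10.0, *params)
print(rate_at_start)  # 4.5

# solve_ivp unpacks its `args` tuple into the model call in the same way
sol = solve_ivp(
    logistic_growth,
    t_span=(0.0, 20.0),
    y0=[10.0],
    args=params,
    t_eval=np.linspace(0.0, 20.0, 5),
)
print(sol.y.shape)  # (1, 5): one state variable, five time points
```

Only the parameter names after `t, y` change from model to model. The solver itself never needs to know what they mean.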
👉 Note that the argument `t` is not used inside the function, because the derivatives of the Lotka-Volterra model do not depend explicitly on time. It still needs to be included in the signature to satisfy the requirements of the solver.

```python
def lotkavolterra(t, y, alpha, beta, gamma, delta):
    # Unpack the current state: prey (X) and predator (Y) abundances
    X, Y = y
    # Rates of change according to the Lotka-Volterra equations
    dXdt = alpha * X - beta * X * Y
    dYdt = gamma * X * Y - delta * Y
    return dXdt, dYdt
```

👉 We can then create our simulation object and assign the model and the solver to it:

```python
# Initialize the simulation object
sim = SimulationBase()

# Configure the case study
sim.config.case_study.name = "ODEtutorial"
sim.config.case_study.scenario = "lotkavolterra"

# Add the model to the simulation
sim.model = lotkavolterra

# Define a solver
sim.solver = JaxSolver
```

## 1.2 Generating artificial data 📈

👉 Now we generate some artificial data that we will later use as our **observations**. To do this, we compute a time series of the Lotka-Volterra model with parameters $\alpha = 0.7, \beta = 0.1, \gamma = 0.1, \delta = 0.9$ from the initial condition $X = 10, Y = 5$ using `solve_ivp`. We could also use `diffrax.diffeqsolve` here, and it would make no difference. The solution is evaluated at 101 equally spaced time points from $t = 0$ to $t = 50$, i.e. with $\Delta t = 0.5$.

👉 We then add normally distributed noise (standard deviation 0.5) to the data and set any negative values to zero, because negative abundances would never be measured in reality.

👉 After running the code, you can take a look at our artificial data and recognize the characteristic periodic oscillations produced by the Lotka-Volterra model.
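A short calculation shows where these oscillations are centred. Setting both right-hand sides of the Lotka-Volterra equations to zero with $X, Y > 0$ gives the non-trivial equilibrium. For the parameter values used here, it lies at:

$$
\begin{aligned}
\frac{dX}{dt} = X(\alpha - \beta Y) = 0 \;&\Rightarrow\; Y^* = \frac{\alpha}{\beta} = \frac{0.7}{0.1} = 7 \\
\frac{dY}{dt} = Y(\gamma X - \delta) = 0 \;&\Rightarrow\; X^* = \frac{\delta}{\gamma} = \frac{0.9}{0.1} = 9
\end{aligned}
$$

Because the initial condition $(X, Y) = (10, 5)$ is not this equilibrium, the prey and predator abundances should cycle around $X^* = 9$ and $Y^* = 7$ rather than settle at those values.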
```python
# Generate a Lotka-Volterra time series
t_eval = np.linspace(0, 50, 101)  # 101 time points, spaced 0.5 apart
sol = solve_ivp(
    lotkavolterra,
    t_span=(0, 50),
    y0=np.array([10.0, 5.0]),
    method="LSODA",
    t_eval=t_eval,
    args=(0.7, 0.1, 0.1, 0.9),  # alpha, beta, gamma, delta
)

# Add normally distributed noise (a fixed seed makes the example reproducible)
rng = np.random.default_rng(seed=1)
noise = rng.normal(0, 0.5, (2, 101))
y_obs = sol.y + noise

# Set negative values to zero: negative abundances cannot be observed
y_obs = np.clip(y_obs, 0, None)

# Save the evaluated time points
t = sol.t

# Plot the generated data
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(t, y_obs[0], label="prey")
ax.plot(t, y_obs[1], label="predator")
ax.set(xlabel="t [-]", ylabel="y_obs [-]", title="Artificial Data")
ax.legend()
plt.tight_layout()
```

## 1.3 Adding data to the `sim` object 🤝

👉 Let's prepare our observations. As seen in the introductory tutorial, Pymob uses `xarray` datasets. Because our model has two state variables, the dataset containing our artificial data also needs two data variables. It must also include the time points at which we generated the data as a coordinate. One way to build such a dataset is the following:

```python
# Create an xarray dataset containing the artificial data
data_obs_1 = xr.DataArray(y_obs[0], coords={"time": t}, dims="time").to_dataset(name="prey")
data_obs_2 = xr.DataArray(y_obs[1], coords={"time": t}, dims="time").to_dataset(name="predator")
data_obs = xr.merge([data_obs_1, data_obs_2])

# Look at the structure of the generated dataset
data_obs
```
```
<xarray.Dataset>
Dimensions: (time: 101)
Coordinates:
* time (time) float64 0.0 0.5 1.0 1.5 2.0 ... 48.0 48.5 49.0 49.5 50.0
Data variables:
prey (time) float64 10.17 11.36 11.85 11.33 ... 11.08 11.16 12.37 11.56
predator (time) float64 5.431 5.33 6.397 7.604 ... 5.544 5.436 7.871 9.127
```
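Passing the data variables directly to the `xr.Dataset` constructor is a more compact way to build an equivalent dataset. The helper `make_observation_dataset` below is hypothetical and not part of Pymob. The sketch uses placeholder arrays with the same shapes as the tutorial's `t` and `y_obs`, so it runs on its own without overwriting those variables.

```python
import numpy as np
import xarray as xr


def make_observation_dataset(t, y):
    """Hypothetical helper: wrap a (2, n_times) array of prey/predator
    observations into an xarray Dataset with a shared "time" coordinate."""
    return xr.Dataset(
        data_vars={
            # Each entry is (dimension name, values)
            "prey": ("time", y[0]),
            "predator": ("time", y[1]),
        },
        coords={"time": t},
    )


# Placeholder inputs with the same shapes as the tutorial's `t` and `y_obs`
t_demo = np.linspace(0, 50, 101)
y_demo = np.arange(202, dtype=float).reshape(2, 101)

ds = make_observation_dataset(t_demo, y_demo)
print(dict(ds.sizes))      # {'time': 101}
print(list(ds.data_vars))  # ['prey', 'predator']
```

In the notebook, `make_observation_dataset(t, y_obs)` should give a dataset equal to the `data_obs` built with `xr.merge` above.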