
AI-driven Automated Discovery Tools Reveal Diverse Behavioral Competencies of Biological Networks

Authors and affiliations:
Mayalen Etcheverry (INRIA, Flowers team, Poietis)
Clément Moulin-Frier (INRIA, Flowers team)
Pierre-Yves Oudeyer (INRIA, Flowers team)
Michael Levin (The Levin Lab, Tufts University)

Published: September 2023

Introduction

TL;DR

This tutorial accompanies our paper Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks.
It is intended to walk you through the set of tools we use to (i) automatically explore the space of input stimuli of a given biological network in order to discover a diversity of behaviors in these systems, (ii) analyze the robustness of the discovered abilities in order to infer the network's navigation competencies in the transcriptional space of gene activation, and (iii) explore possible reuses of the constructed "behavioral catalog" in a biomedical context.

📝 How to follow this tutorial

💻 AutoDiscJax

Throughout this tutorial, we will be using the AutoDiscJax library, a library built on top of jax and equinox to facilitate automated experimentation and simulation of biological network pathways.

AutoDiscJax follows two main design principles: 1) Everything is a module, where a module is simply a parametrized function that takes inputs and returns outputs (and log_data). All AutoDiscJax modules (adx.Module) are implemented as equinox modules (eqx.Module), which essentially allows the function to be represented as a callable PyTree (and hence to be compatible with jax transformations) while keeping an intuitive API for model building (a python class with a __call__ method). The only addition with respect to equinox is that, when instantiating an adx.Module, the user must specify the PyTree structure, shape and dtype of the module's outputs. 2) An experiment pipeline defines (i) how modules interact sequentially and exchange information, and (ii) what information should be collected and saved in the experiment history.
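To make the two principles concrete, here is a minimal sketch in plain Python. It deliberately uses neither equinox nor the AutoDiscJax API: the names `ScaleModule`, `OffsetModule` and `run_pipeline` are hypothetical and only illustrate the idea of parametrized callables that return `(outputs, log_data)`, and of a pipeline that chains them while recording a history.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ScaleModule:
    """Hypothetical module: a parametrized function that scales its input."""
    factor: float

    def __call__(self, x):
        outputs = self.factor * x
        log_data = {"module": "scale", "input": x}
        return outputs, log_data


@dataclass(frozen=True)
class OffsetModule:
    """Hypothetical module: a parametrized function that shifts its input."""
    offset: float

    def __call__(self, x):
        outputs = x + self.offset
        log_data = {"module": "offset", "input": x}
        return outputs, log_data


def run_pipeline(modules, x):
    """Call the modules in sequence and collect what each one logged."""
    history = []
    for module in modules:
        x, log_data = module(x)
        history.append({"outputs": x, **log_data})
    return x, history


final_output, history = run_pipeline([ScaleModule(2.0), OffsetModule(1.0)], 3.0)
```

Here `final_output` is the result of the last module, and `history` holds one record per module call, which is the kind of information an experiment pipeline would save.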

AutoDiscJax provides a handful of already-implemented modules and pipelines to 1) simulate biological networks while intervening on them according to our needs (see Part 1), 2) automatically organize experimentation in those systems, by implementing a variety of exploration approaches such as random, optimization-driven and curiosity-driven search (see Part 2 and Part 3), and 3) analyze the discoveries of the exploration method, for instance by testing their robustness to various perturbations (see Part 4).

Finally, AutoDiscJax takes advantage of JAX's main features (just-in-time compilation, automatic vectorization and automatic differentiation), which are especially advantageous for parallel experimentation and computational speedups, as well as for gradient-based optimization.

Part 1: Numerical simulation (with interventions) of a GRN model

Ordinary differential equation (ODE) models are widely used to represent the behavior of complex biological processes. These ODE models are often experimentally determined and curated by biologists, e.g. by combining experimental and data-driven methods to determine whether and how pairs of biomolecules interact. The resulting mathematical models are usually stored and exchanged using the Systems Biology Markup Language (SBML). Thanks to community efforts, large collections of published ODE models have been made publicly available in online databases, such as the BioModels website.

For this tutorial, we will study the gene regulatory network (GRN) model that describes the influence of RKIP on the ERK signaling pathway. The model is described in Cho et al.'s paper and is hosted on the BioModels database under the identifier BIOMD0000000647. Its structure is summarized by the following reaction graph: (figure: reaction graph of the model)

Whereas the SBML file provides information about the different species, parameters and reactions involved in this model, we must use our own tools for carrying out simulation studies.

In the first part of this tutorial, we will be showing how to use the SBMLtoODEjax library to automatically retrieve and parse the SBML file into a python file, and the AutoDiscJax library to easily simulate the model and manipulate it based on our needs.

biomodel_idx = 647
observed_node_names = ["ERK", "RKIPP_RP"]

SBML to python conversion

Let's first use the SBMLtoODEjax library to download the SBML file and generate the corresponding python class that will allow us to simulate the model's dynamics.

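import os
import sbmltoodejax

out_model_sbml_filepath = f"data/biomodel_{biomodel_idx}.xml"
out_model_jax_filepath = f"data/biomodel_{biomodel_idx}.py"

# Make sure the output folder exists before writing to it
os.makedirs(os.path.dirname(out_model_sbml_filepath), exist_ok=True)

# Download the SBML file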
if not os.path.exists(out_model_sbml_filepath):
    model_xml_body = sbmltoodejax.biomodels_api.get_content_for_model(biomodel_idx)
    with open(out_model_sbml_filepath, 'w') as f:
        f.write(model_xml_body)

# Generation of the python class from the SBML file
if not os.path.exists(out_model_jax_filepath):
    model_data = sbmltoodejax.parse.ParseSBMLFile(out_model_sbml_filepath)
    sbmltoodejax.modulegeneration.GenerateModel(model_data, out_model_jax_filepath)

At this point, you should have created a biomodel_647.py file that contains several variables and functional modules as follows:

The variables are split in this way because, for many models, the kinematic parameters are all constant ($w=\emptyset$), whereas for others some kinematic parameters $w(t)$ evolve in time in addition to the constant ones $c$.
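As a rough illustration of this split, the sketch below steps a toy one-species model in plain Python. It is not the code generated by SBMLtoODEjax: the decay model, the function names `model_step` and `model_rollout`, and the forward-Euler update are all assumptions chosen for brevity. It only shows how species amounts y, time-varying parameters w and constants c can be threaded through a step function.

```python
def model_step(y, w, c, t, delta_t):
    """Advance the toy model dy/dt = -c[0] * y[0] by one forward-Euler step.

    y: species amounts, w: time-varying parameters (empty here), c: constants.
    """
    dy_dt = [-c[0] * y[0]]
    new_y = [y_i + delta_t * dy_i for y_i, dy_i in zip(y, dy_dt)]
    return new_y, w, c, t + delta_t


def model_rollout(y0, w0, c, n_steps, delta_t):
    """Apply model_step repeatedly and record the state after each step."""
    y, w, t = y0, w0, 0.0
    ys, ts = [], []
    for _ in range(n_steps):
        y, w, c, t = model_step(y, w, c, t, delta_t)
        ys.append(y)
        ts.append(t)
    return ys, ts


# One species decaying at rate c[0] = 0.5, with no time-varying parameter (w is empty)
ys, ts = model_rollout(y0=[1.0], w0=[], c=[0.5], n_steps=100, delta_t=0.1)
```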

For those interested in understanding more about the specifications of SBML files and SBMLtoODEjax files, we refer to the SBMLtoODEjax documentation: https://developmentalsystems.org/sbmltoodejax/design_principles.html

System Rollout module

Ok, let's now create our first module: the System Rollout. In short, the system rollout can be seen as a wrapper around the previously created sbmltoodejax ModelRollout module, but one that allows us to apply all sorts of interventions and/or perturbations during the rollout simulation without having to modify the rollout codebase. We'll see how to do that shortly, but first let's create our system rollout module and see what a simulation looks like.

When instantiating the ModelStep module, we should specify the desired time step $\Delta t$ and the ODE solver parameters $(atol, rtol, mxstep)$.

# System Rollout Config
system_rollout_config = Dict()
system_rollout_config.system_type = "grn"
system_rollout_config.model_filepath = out_model_jax_filepath # path of model class that we just created using sbmltoodejax
system_rollout_config.atol = 1e-6 # parameters for the ODE solver 
system_rollout_config.rtol = 1e-12
system_rollout_config.mxstep = 1000
system_rollout_config.deltaT = 0.1 # the ODE solver will compute values every 0.1 second
system_rollout_config.n_secs = 1000 # number of seconds simulated in one rollout
system_rollout_config.n_system_steps = int(system_rollout_config.n_secs/system_rollout_config.deltaT) # total number of steps returned after a rollout

# Create the module
system_rollout = create_system_rollout_module(system_rollout_config)

# Get observed node ids
observed_node_ids = [system_rollout.grn_step.y_indexes[node_name] for node_name in observed_node_names]

Let's now simulate the module. By default, the rollout starts with the initial concentrations as given in the SBML file.

key = jrandom.PRNGKey(0)  # assumes `import jax.random as jrandom`; the seed value is arbitrary
key, subkey = jrandom.split(key)
# All autodiscjax modules take a random key as input (needed e.g. for random sampling operations within the module)
default_system_outputs, log_data = system_rollout(subkey)

Here is what we obtain when simulating the network with the default initial conditions (see the original Figure 5 of Cho et al.'s paper):

Figure 1: Simulation results of the mathematical modeling for default initial condition.

We can also observe the resulting trajectories in phase space, also called transcriptional space for gene regulatory networks. Let's say, for instance, that we are interested in observing the activation response of two specific nodes: ERK and RKIPP_RP. By plotting their trajectory in the transcriptional space, we can see that they start at point A(0,0) and navigate until they reach a steady state at point B(0.036,0.055).

Figure 2: Simulation results for default initial condition. Trajectory of nodes ERK and RKIPP_RP is shown in transcriptional space.
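One simple way to decide that a trajectory has reached a steady state is to check that it has stopped moving in the observed space. The sketch below is an assumption about how such a check could be written, not a function from AutoDiscJax. The toy trajectory only borrows the coordinates of point B from the text; it is an exponential relaxation and does not come from the ERK model.

```python
import math


def settling_index(trajectory, tol=1e-6):
    """Return the first index after which every step moves by less than tol.

    Returns None if the last step itself is still larger than tol.
    """
    index = None
    for i in range(len(trajectory) - 1, 0, -1):
        if math.dist(trajectory[i], trajectory[i - 1]) >= tol:
            break
        index = i - 1
    return index


# Toy trajectory relaxing from A(0, 0) toward B (not the real ERK dynamics)
B = (0.036, 0.055)
trajectory = [(B[0] * (1 - 0.9 ** k), B[1] * (1 - 0.9 ** k)) for k in range(300)]
index = settling_index(trajectory)
```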

Applying interventions

We would now like to apply several interventions to the system rollout to see how they influence the trajectory of the network in transcriptional space. The system rollout module allows us to do so by specifying an intervention function $(y^-, y, w^-, w, c^-, c, t^-, t) \mapsto (y, w, c, t)$ that is called at every step during a rollout, where $(y^-,w^-,c^-,t^-)$ are the variable values before the step and $(y,w,c,t)$ are the values that the ModelStep function would return without intervention (and that the intervention function overwrites).
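This calling convention can be sketched as follows, in plain Python and on a toy model. It illustrates the signature described above rather than AutoDiscJax's implementation: `toy_step`, `rollout` and `clamp_first_species` are hypothetical names, and the toy dynamics are arbitrary.

```python
def toy_step(y, w, c, t, delta_t):
    """Toy dynamics: each species relaxes toward 1.0 at rate c[0]."""
    new_y = [y_i + delta_t * c[0] * (1.0 - y_i) for y_i in y]
    return new_y, w, c, t + delta_t


def no_intervention(y_prev, y, w_prev, w, c_prev, c, t_prev, t):
    """Default hook: keep what the step function returned."""
    return y, w, c, t


def clamp_first_species(y_prev, y, w_prev, w, c_prev, c, t_prev, t):
    """Example hook: overwrite species 0 with 0.2 while t <= 1.0."""
    if t <= 1.0:
        y = [0.2] + list(y[1:])
    return y, w, c, t


def rollout(y0, w0, c0, n_steps, delta_t, intervention_fn=no_intervention):
    """Step the model and call the intervention hook after every step."""
    y, w, c, t = y0, w0, c0, 0.0
    ys = []
    for _ in range(n_steps):
        y_new, w_new, c_new, t_new = toy_step(y, w, c, t, delta_t)
        # The hook sees both the values before the step and the proposed ones
        y, w, c, t = intervention_fn(y, y_new, w, w_new, c, c_new, t, t_new)
        ys.append(y)
    return ys


free = rollout([0.0, 0.0], [], [0.5], n_steps=50, delta_t=0.1)
clamped = rollout([0.0, 0.0], [], [0.5], n_steps=50, delta_t=0.1,
                  intervention_fn=clamp_first_species)
```

Because the hook is passed as an argument, the rollout loop itself never changes: a different experiment only requires a different function.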

Several intervention functions are already provided in AutoDiscJax. In particular, the PiecewiseIntervention module controls the variables of the network with a piecewise-defined function, which makes it possible to apply a broad range of interventions depending on one's needs and constraints. The module configuration lets us specify which variable(s) (stored in y, w or c) to intervene on, when (on which time intervals), and how (e.g. by setting the variable value with PiecewiseSetConstantIntervention, or by adding to the variable's current value with PiecewiseAddConstantIntervention).
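A piecewise-constant intervention boils down to two pieces: a lookup from the current time to the index of the active interval (if any), and a rule for what to do with the value attached to that interval. The sketch below shows both in plain Python. The helper names are hypothetical, and the half-open interval test is an assumption: AutoDiscJax's own modules may treat interval boundaries differently.

```python
def time_to_interval(t, intervals):
    """Return the index of the interval [start, end) containing t, or None."""
    for idx, (start, end) in enumerate(intervals):
        if start <= t < end:
            return idx
    return None


def piecewise_set_constant(value, t, intervals, constants):
    """Overwrite the value with the constant of the active interval."""
    idx = time_to_interval(t, intervals)
    return value if idx is None else constants[idx]


def piecewise_add_constant(value, t, intervals, constants):
    """Add the constant of the active interval to the current value."""
    idx = time_to_interval(t, intervals)
    return value if idx is None else value + constants[idx]


delta_t = 0.1
# Active only at t=0: a way to overwrite an initial condition
initial_only = [[0.0, delta_t / 2.0]]
# Two clamping windows, each with its own constant
windows = [[0.0, 10.0], [400.0, 410.0]]
levels = [2.5, 1.0]
```

With a time step of 0.1, an interval such as `[0, delta_t / 2]` contains only the very first time point, which is why it acts as a change of initial condition rather than as a sustained clamp.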

Let's try it and see how our ERK-pathway network reacts to the following interventions: 1) start the trajectory in ERK-RKIPP_RP space from another point A'(0.03, 0.02), 2) clamp node MEKPP to 2.5 during the first 10 seconds, and then to 1.0 during 10 additional seconds at t=400, 3) change the value of the kinematic parameter k5 to 0.1.

Intervention 1: changing the initial species amount

# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_values = [[0.03], [0.02]]

intervention_fn_1 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_1 = DictTree() # A DictTree is a Dict container (i.e. a dictionary whose items can be get and set like attributes) that is registered as a Jax PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_1.y[node_idx] = jnp.array(node_value)

# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_1, intervention_params=intervention_params_1)
Figure 3: Simulation results for modified initial condition A'. Trajectory is shown for nodes ERK and RKIPP_RP in transcriptional space.

👉 Interestingly, we can see that despite forcing the network to start at another point A', it still converges to the same point B. Moreover, instead of going directly from A' to B, the network seems to take a "detour" to rejoin its initial trajectory, and then follows that trajectory again until point B.

Intervention 2: clamping the species amount to specific values

# Create the intervention
controlled_intervals = [[0, 10], [400, 410]]
controlled_node_names = ["MEKPP"]
controlled_node_values = [[2.5, 1.0]]

intervention_fn_2 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_2 = DictTree() # A DictTree is a Dict container (i.e. a dictionary whose items can be get and set like attributes) that is registered as a Jax PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_2.y[node_idx] = jnp.array(node_value)

# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_2, intervention_params=intervention_params_2)
Figure 4: Simulation results for default initial condition A, with clamping of node MEKPP. (left) Evolution of MEKPP (with clamp interventions) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

👉 Here, the clamping of MEKPP seems to have some effect on ERK but not on RKIPP_RP. Indeed, after the initial clamping of MEKPP to 2.5 (for 10 seconds), the trajectory of the ERK-RKIPP_RP pair still follows a very similar S-shaped curve and arrives close to the original point B, but with a slightly lower ERK expression level (t=400). Once we re-clamp MEKPP to a lower activation level (1.0 for 10 seconds at t=400), we see an effect on the ERK expression level: the final steady state B' is shifted to the right.

Intervention 3: changing the kinematic parameters

# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_param_names = ["k5"]
controlled_param_values = [[0.1]]

intervention_fn_3 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_3 = DictTree() 
for (param_name, param_value) in zip(controlled_param_names, controlled_param_values):
    param_idx = system_rollout.grn_step.c_indexes[param_name] #this time we specify intervention parameter value for the key "c"
    intervention_params_3.c[param_idx] = jnp.array(param_value) 
    print(f"Initial {param_name} param value: {system_rollout.c[param_idx]}, changed to {param_value}")


# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_3, intervention_params=intervention_params_3)
Figure 5: Simulation results for default initial condition A and modified kinematic parameter k5. Trajectory of (ERK, RKIPP_RP) is shown in transcriptional space.

👉 Here we can see that changing the parameter k5 shifts the end point of the trajectory quite significantly, from B to B', but that qualitatively the trajectory keeps a similar "S" shape.
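Why would a rate constant move the end point without changing the overall shape? A toy production-degradation model gives some intuition. This model is unrelated to the ERK pathway and is used here purely as an illustrative assumption: its steady state is the ratio of its two rate constants, so changing one of them moves where the trajectory settles while the trajectory remains a simple relaxation curve.

```python
def simulate(k_prod, k_deg, y0=0.0, n_steps=2000, delta_t=0.1):
    """Forward-Euler rollout of the toy model dy/dt = k_prod - k_deg * y."""
    y = y0
    trajectory = [y]
    for _ in range(n_steps):
        y = y + delta_t * (k_prod - k_deg * y)
        trajectory.append(y)
    return trajectory


# The analytical steady state of this toy model is k_prod / k_deg
baseline = simulate(k_prod=0.05, k_deg=0.5)  # settles near 0.05 / 0.5 = 0.1
modified = simulate(k_prod=0.05, k_deg=0.1)  # settles near 0.05 / 0.1 = 0.5
```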

Applying perturbations

Similarly, we would like to apply several perturbations to the system rollout to see how they influence the trajectory of the network in transcriptional space. In AutoDiscJax, a perturbation is implemented in the same manner as an intervention (by specifying a perturbation function that is also called at every step during a rollout). The difference is only conceptual: a perturbation is supposed to represent uncontrolled events (such as noise or other environmental stresses), whereas an intervention is supposed to represent controlled events (i.e. events applied by the experimenter, such as drug stimuli).

Several perturbation functions are already provided in AutoDiscJax. Below, we build them with create_perturbation_module, which returns two objects: a perturbation generator, which produces the perturbation parameters for a given rollout, and a perturbation function, which applies them during the rollout.

Let's try it and see how our ERK-pathway network reacts to the following perturbations: 1) add noise (std=0.01) to all nodes at regular times between t=0 and t=80, 2) apply a transient push of magnitude 0.1 to the observed nodes (ERK, RKIPP_RP) at t=3.

Perturbation 1: applying dynamical noise

# We apply noise between t=0 and t=80 secs to all nodes
start = 0 
end = 80
deltaT = system_rollout.deltaT
perturbed_node_ids = list(range(len(system_rollout.grn_step.y_indexes)))

# Create the noise perturbation generator module
noise_perturbation_generator_config = Dict()
noise_perturbation_generator_config.perturbation_type = "noise"
noise_perturbation_generator_config.perturbed_intervals = [[t-deltaT/2, t+deltaT/2] for t in jnp.linspace(start, end, 2 + int((end-start)/5))[1:-1]] # add noise roughly every 5 sec between start and end (end points excluded)
noise_perturbation_generator_config.perturbed_node_ids = perturbed_node_ids
noise_perturbation_generator_config.std = 0.01

noise_perturbation_generator, noise_perturbation_fn = create_perturbation_module(noise_perturbation_generator_config)


# Run the system with the perturbation
key, subkey = jrandom.split(key)
noise_perturbation_params, log_data = noise_perturbation_generator(subkey, default_system_outputs)
noise_system_outputs, log_data = system_rollout(subkey, None, None,
                                                        noise_perturbation_fn, noise_perturbation_params)
Figure 6: Simulation results for default initial condition A, with noise perturbation. (left) Evolution of gene expressions (with noise perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.
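To see exactly when the noise is applied, the schedule built in the configuration above can be reproduced in plain Python. The `linspace` helper below mimics `jnp.linspace` with its default inclusive end point. The Gaussian draw is an assumption about what a noise perturbation with a given `std` might do, not a description of AutoDiscJax's internals, and the state `y` is made up.

```python
import random


def linspace(start, stop, num):
    """Evenly spaced values from start to stop, both included."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


start, end, delta_t = 0, 80, 0.1

# Same expression as in the configuration: drop both end points
times = linspace(start, end, 2 + int((end - start) / 5))[1:-1]
intervals = [[t - delta_t / 2, t + delta_t / 2] for t in times]
spacing = times[1] - times[0]

# Hypothetical noise application: add zero-mean Gaussian noise to every node
rng = random.Random(0)
y = [0.5, 0.2, 0.1]
noisy_y = [y_i + rng.gauss(0.0, 0.01) for y_i in y]
```

With start=0 and end=80, this yields 16 perturbation times spaced by 80/17 (about 4.7 seconds), so the noise is applied slightly more often than every 5 seconds and never exactly at t=0 or t=80.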

Perturbation 2: applying transient kicks

push_t = 3

# Create the push perturbation generator module
push_perturbation_generator_config = Dict()
push_perturbation_generator_config.perturbation_type = "push"
push_perturbation_generator_config.perturbed_intervals = [[push_t-deltaT/2, push_t+deltaT/2]]
push_perturbation_generator_config.perturbed_node_ids = observed_node_ids
push_perturbation_generator_config.magnitude = 0.1
push_perturbation_generator, push_perturbation_fn = create_perturbation_module(push_perturbation_generator_config)

# Run the system with the perturbation
key, subkey = jrandom.split(key)
push_perturbation_params, log_data = push_perturbation_generator(subkey, default_system_outputs)
push_system_outputs, log_data = system_rollout(subkey, None, None,
                                                        push_perturbation_fn, push_perturbation_params)
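
A transient kick can be pictured as a one-off displacement of the observed nodes at a single time step. The sketch below is a guess at the general idea, written in plain Python: the choice of a random direction with a fixed norm equal to `magnitude` is an assumption, and `random_push` is a hypothetical helper, not an AutoDiscJax function.

```python
import math
import random


def random_push(point, magnitude, rng):
    """Displace a 2D point by `magnitude` in a uniformly random direction."""
    angle = rng.uniform(0.0, 2.0 * math.pi)
    dx, dy = magnitude * math.cos(angle), magnitude * math.sin(angle)
    return (point[0] + dx, point[1] + dy)


rng = random.Random(0)
before = (0.02, 0.03)  # arbitrary (ERK, RKIPP_RP) position at push time
after = random_push(before, magnitude=0.1, rng=rng)
displacement = math.dist(before, after)
```

Whatever the sampled direction, the distance between `before` and `after` equals the requested magnitude, which is what makes pushes of different magnitudes comparable in a robustness analysis.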