
AI-driven Automated Discovery Tools Reveal Diverse Behavioral Competencies of Biological Networks

Authors: Mayalen Etcheverry (INRIA, Flowers team, Poietis), Clément Moulin-Frier (INRIA, Flowers team), Pierre-Yves Oudeyer (INRIA, Flowers team), Michael Levin (The Levin Lab, Tufts University)
Published: September 2023

Introduction

TL;DR

This tutorial accompanies our paper Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks.
It is intended to walk you through the set of tools we use to (i) automatically explore the space of input stimuli of a given biological network in order to discover a diversity of behaviors in these systems, (ii) analyze the robustness of the discovered abilities in order to infer the network's navigation competencies in the transcriptional space of gene activation, and (iii) explore possible reuses of the constructed "behavioral catalog" in a biomedical context.

📝 How to follow this tutorial

💻 AutoDiscJax

Throughout this tutorial, we will be using AutoDiscJax, a library built on top of jax and equinox that facilitates automated experimentation and simulation of biological network pathways.

AutoDiscJax follows two main design principles: 1) Everything is a module, where a module is simply a parametrized function that takes inputs and returns outputs (and log_data). All autodiscjax modules adx.Module are implemented as equinox modules eqx.Module, which essentially allows representing the function as a callable PyTree (and hence makes it compatible with jax transformations) while keeping an intuitive API for model building (a python class with a __call__ method). The only add-on with respect to equinox is that, when instantiating an adx.Module, the user must specify the PyTree structure, shape and dtype of the module's outputs. 2) An experiment pipeline defines (i) how modules interact sequentially and exchange information, and (ii) what information should be collected and saved in the experiment history.
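
To make the "callable PyTree" idea concrete, here is a minimal sketch written directly with equinox. It is purely illustrative: the exact adx.Module API (in particular how the output PyTree structure, shape and dtype are declared) is not reproduced here.

import equinox as eqx
import jax.numpy as jnp

class Scale(eqx.Module):
    # parameters are stored as PyTree leaves, so the module itself is a PyTree
    gain: jnp.ndarray

    def __call__(self, key, x):
        # like autodiscjax modules, take a PRNG key as first argument and return (outputs, log_data)
        return self.gain * x, {}

module = Scale(gain=jnp.array(2.0))
outputs, log_data = module(None, jnp.arange(3.0))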

AutoDiscJax provides a handful of already-implemented modules and pipelines to 1) simulate biological networks while intervening on them according to our needs (see Part 1), 2) automatically organize experimentation on those systems, by implementing a variety of exploration approaches such as random, optimization-driven and curiosity-driven search (see Part 2 and Part 3), and 3) analyze the discoveries of the exploration method, for instance by testing their robustness to various perturbations (see Part 4).

Finally, AutoDiscJax takes advantage of JAX's main features (just-in-time compilation, automatic vectorization and automatic differentiation), which are especially advantageous for parallel experimentation and computational speedups, as well as for gradient-based optimization.
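
As a reminder, these transformations compose freely on plain functions. The toy example below (unrelated to the GRN model) only illustrates the jit/vmap/grad combination that the modules of this tutorial rely on.

import jax.numpy as jnp
from jax import grad, jit, vmap

def loss(theta, x):
    return jnp.sum((theta * x - 1.0) ** 2)

# gradient w.r.t. theta, vectorized over a batch of x values, and jit-compiled
batched_grad = jit(vmap(grad(loss), in_axes=(None, 0)))
print(batched_grad(2.0, jnp.arange(4.0)).shape)  # (4,)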

Part 1: Numerical simulation (with interventions) of a GRN model

Ordinary differential equation (ODE) models are widely used to represent the behavior of complex biological processes. These ODE models are often experimentally determined and curated by biologists, e.g. by combining experimental and data-driven methods to determine whether and how pairs of biomolecules interact. The resulting mathematical models are usually stored and exchanged using the Systems Biology Markup Language (SBML). Thanks to community efforts, large collections of published ODE models have been made publicly available in online databases, such as the BioModels website.

For this tutorial, we will be studying the gene regulatory network (GRN) model that describes the influence of RKIP on the ERK signaling pathway. The model comes from Cho et al.'s paper, is hosted on the BioModels database under the identifier BIOMD0000000647, and is described by the following reaction graph:

Whereas the SBML file provides information about the different species, parameters and reactions involved in this model, we must use our own tools for carrying out simulation studies.

In the first part of this tutorial, we will be showing how to use the SBMLtoODEjax library to automatically retrieve and parse the SBML file into a python file, and the AutoDiscJax library to easily simulate the model and manipulate it based on our needs.
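
The code cells below assume a few standard imports and a PRNG key; a possible setup cell is sketched here. The jax and sbmltoodejax imports are standard, whereas the AutoDiscJax-specific names used throughout (Dict, DictTree, grn, the create_*_module helpers, optimizers, ...) come from autodiscjax and its utilities, whose exact import paths are omitted in this tutorial.

import os

import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
from jax import vmap

import sbmltoodejax

# PRNG key threaded through all stochastic module calls below (seed value is arbitrary)
key = jrandom.PRNGKey(0)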

biomodel_idx = 647
observed_node_names = ["ERK", "RKIPP_RP"]

SBML to python conversion

Let's first use the SBMLtoODEjax library to download the SBML file and generate the corresponding python class that will allow us to simulate the model's dynamics.

out_model_sbml_filepath = f"data/biomodel_{biomodel_idx}.xml"
out_model_jax_filepath = f"data/biomodel_{biomodel_idx}.py"

# Download the SBML file
if not os.path.exists(out_model_sbml_filepath):
    model_xml_body = sbmltoodejax.biomodels_api.get_content_for_model(biomodel_idx)
    with open(out_model_sbml_filepath, 'w') as f:
        f.write(model_xml_body)

# Generation of the python class from the SBML file
if not os.path.exists(out_model_jax_filepath):
    model_data = sbmltoodejax.parse.ParseSBMLFile(out_model_sbml_filepath)
    sbmltoodejax.modulegeneration.GenerateModel(model_data, out_model_jax_filepath)

At this point, you should have created a biomodel_647.py file that contains several variables and functional modules as follows:

The variables are instantiated this way because, for many models, the kinematic parameters are all constant ($w = \emptyset$), whereas for others some kinematic parameters $w(t)$ evolve in time in addition to the constant ones $c$.

For those who are interested in understanding more about the specifications of SBML files and SBMLtoODEjax files, we refer to the SBMLtoODEjax documentation: https://developmentalsystems.org/sbmltoodejax/design_principles.html

System Rollout module

Ok, let's now create our first module: the System Rollout. In short, the system rollout can be seen as a wrapper of the previously-created sbmltoodejax ModelRollout module, but this time allowing us to apply all sorts of interventions and/or perturbations during the rollout simulation, without having to modify the rollout codebase. We'll see how to do that shortly, but first let's create our system rollout module and see what a simulation looks like.

When instantiating the module, we should specify the desired time step $\Delta t$ and the ODE solver parameters $(atol, rtol, mxstep)$:

# System Rollout Config
system_rollout_config = Dict()
system_rollout_config.system_type = "grn"
system_rollout_config.model_filepath = out_model_jax_filepath # path of model class that we just created using sbmltoodejax
system_rollout_config.atol = 1e-6 # parameters for the ODE solver 
system_rollout_config.rtol = 1e-12
system_rollout_config.mxstep = 1000
system_rollout_config.deltaT = 0.1 # the ODE solver will compute values every 0.1 second
system_rollout_config.n_secs = 1000 # number of seconds of one rollout of the system
system_rollout_config.n_system_steps = int(system_rollout_config.n_secs/system_rollout_config.deltaT) # total number of steps returned after a rollout

# Create the module
system_rollout = create_system_rollout_module(system_rollout_config)

# Get observed node ids
observed_node_ids = [system_rollout.grn_step.y_indexes[observed_node_names[0]], system_rollout.grn_step.y_indexes[observed_node_names[1]]]

Let's now simulate the module. By default, the rollout starts with the initial concentrations as given in the SBML file.

key, subkey = jrandom.split(key)
default_system_outputs, log_data = system_rollout(subkey) # all autodiscjax modules take the state of the "random seed" as input (needed e.g. for random sampling operations within the module)

Here is what we obtain when simulating the network with the default initial conditions (see the original Figure 5 of Cho et al.'s paper):

Figure 1: Simulation results of the mathematical modeling for default initial condition.

We can also observe the resulting trajectories in phase space, also called the transcriptional space for gene regulatory networks. Let's say, for instance, that we are interested in observing the activation response of two specific nodes: ERK and RKIPP_RP. By plotting their trajectory in the transcriptional space, we can see that they start at point A(0,0) and navigate until they reach a steady state at point B(0.036,0.055).

Figure 2: Simulation results for default initial condition. Trajectory of nodes ERK and RKIPP_RP is shown in transcriptional space.
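
For reference, a plot like Figure 2 can be produced directly from the rollout outputs. The minimal sketch below assumes that ys is indexed as (species, time steps), consistent with how it is used in the rest of this tutorial.

import numpy as np
import plotly.graph_objects as go

erk_traj, rkipp_rp_traj = default_system_outputs.ys[jnp.array(observed_node_ids)]
fig = go.Figure(go.Scatter(x=np.array(erk_traj), y=np.array(rkipp_rp_traj), mode="lines", name="default rollout"))
fig.update_xaxes(title="ERK")
fig.update_yaxes(title="RKIPP_RP")
fig.show()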

Applying interventions

We would now like to apply several interventions to the system rollout to see how they influence the trajectory of the network in transcriptional space. The system rollout module allows us to do this by specifying an intervention function $(y^-, y, w^-, w, c^-, c, t^-, t) \mapsto (y, w, c, t)$ that is called at every step of a rollout, where $(y^-, w^-, c^-, t^-)$ are the variable values before the step (at $t-1$) and $(y, w, c, t)$ are the variable values that would be returned by the ModelStep function without intervention (and that are overwritten by the intervention function).
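
Conceptually, such an intervention function could look like the sketch below, which clamps the first species to a fixed value during the first 10 seconds and leaves everything else untouched. This is only an illustration of the mathematical signature above; the intervention modules actually used in this tutorial (introduced next) are parametrized AutoDiscJax modules rather than bare functions.

def clamp_first_species(y_prev, y, w_prev, w, c_prev, c, t_prev, t):
    # overwrite the ODE step's output for species 0 during the first 10 seconds
    y = jnp.where(t < 10.0, y.at[0].set(1.0), y)
    return y, w, c, t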

Several intervention functions are already provided in AutoDiscJax. In particular, the PiecewiseIntervention module controls the variables of the network with a piecewise-defined function, which makes it possible to apply a broad range of interventions depending on one's needs and constraints. The module configuration lets us specify which variable(s) (stored in y, w or c) to intervene on, when (on which time intervals), and how (e.g. by setting the variable value with PiecewiseSetConstantIntervention or by adding to the variable's current value with PiecewiseAddConstantIntervention).

Let's try it and see how our ERK-pathway network reacts to the following interventions: 1) start the trajectory in ERK-RKIPP_RP space from another point A'(0.03, 0.02), 2) clamp node MEKPP to 2.5 during the first 10 seconds and then to 1.0 for 10 more seconds at t=400, 3) change the value of the kinematic parameter k5 to 0.1.

Intervention 1: changing the initial species amount

# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_values = [[0.03], [0.02]]

intervention_fn_1 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_1 = DictTree() # A DictTree is a Dict container (i.e. a dictionary whose items can be accessed and set like attributes) that is registered as a Jax PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_1.y[node_idx] = jnp.array(node_value)

# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_1, intervention_params=intervention_params_1)
Figure 3: Simulation results for modified initial condition A'. Trajectory is shown for nodes ERK and RKIPP_RP in transcriptional space.

👉 Interestingly, we can see that despite forcing the network to start at another point A', it still converges to the same point B. Moreover, instead of going directly from A' to B, the network seems to make a "detour" to get back to its initial trajectory, and then follows that trajectory again until it successfully reaches point B.
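
The PiecewiseAddConstantIntervention variant mentioned earlier follows the same usage pattern; as a hedged sketch (its constructor is assumed here to mirror the Set variant above), adding +0.01 to the initial ERK amount instead of setting it could look like:

# constructor and parameter layout assumed to mirror PiecewiseSetConstantIntervention
intervention_fn_add = grn.PiecewiseAddConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=[[0, system_rollout_config.deltaT/2.0]]))
intervention_params_add = DictTree()
intervention_params_add.y[system_rollout.grn_step.y_indexes["ERK"]] = jnp.array([0.01])  # added to (not replacing) the current value

key, subkey = jrandom.split(key)
system_outputs_add, log_data = system_rollout(subkey, intervention_fn=intervention_fn_add, intervention_params=intervention_params_add)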

Intervention 2: clamping the species amount to specific values

# Create the intervention
controlled_intervals = [[0, 10], [400, 410]]
controlled_node_names = ["MEKPP"]
controlled_node_values = [[2.5, 1.0]]

intervention_fn_2 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_2 = DictTree() # A DictTree is a Dict container (i.e. a dictionary whose items can be accessed and set like attributes) that is registered as a Jax PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_2.y[node_idx] = jnp.array(node_value)

# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_2, intervention_params=intervention_params_2)
Figure 4: Simulation results for default initial condition A, with clamping of node MEKPP. (left) Evolution of MEKPP (with clamp interventions) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

👉 Here, the clamping of MEKPP seems to have some effect on ERK but not on RKIPP_RP. Indeed, after the initial clamping of MEKPP to 2.5 (during 10 seconds), the trajectory of the ERK-RKIPP_RP pair still follows a very similar S-shaped curve and arrives close to the original point B, but with a slightly lower ERK expression level (t=400). From the moment we re-clamp MEKPP to a lower activation (1.0 for 10 seconds at t=400), we see an effect on the ERK expression level, where the final steady state B' gets shifted to the right.

Intervention 3: changing the kinematic parameters

# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_param_names = ["k5"]
controlled_param_values = [[0.1]]

intervention_fn_3 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_3 = DictTree() 
for (param_name, param_value) in zip(controlled_param_names, controlled_param_values):
    param_idx = system_rollout.grn_step.c_indexes[param_name] #this time we specify intervention parameter value for the key "c"
    intervention_params_3.c[param_idx] = jnp.array(param_value) 
    print(f"Initial {param_name} param value: {system_rollout.c[param_idx]}, changed to {param_value}")


# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_3, intervention_params=intervention_params_3)
Figure 5: Simulation results for default initial condition A and modified kinematic parameter k5. Trajectory of (ERK, RKIPP_RP) is shown in transcriptional space.

👉 Here we can see that changing the parameter k5 shifts the trajectory endpoint quite significantly from B to B', but that qualitatively the trajectory seems to keep a similar "S" shape.

Applying perturbations

Similarly, we would like to apply several perturbations to the system rollout to see how they influence the trajectory of the network in transcriptional space. In AutoDiscJax, a perturbation is implemented in the same manner as an intervention (by specifying a perturbation function that is also called at every step of a rollout). The difference is only conceptual: a perturbation is meant to represent uncontrolled events (such as noise or other environmental stresses), whereas an intervention is meant to represent controlled events (i.e., applied by the experimenter, such as drug stimuli).

Several perturbation functions are already provided in AutoDiscJax. In particular, as we will see below, they allow us to add stochastic noise to the species amounts, to apply transient "push" kicks to the trajectory, and to place "wall" obstacles in the transcriptional space. As for interventions, the perturbation module configuration lets us specify which variable(s) to perturb, when (on which time intervals), and how.

Let's try it and see how our ERK-pathway network reacts to the following perturbations: 1) dynamical noise added to all nodes during the first 80 seconds, 2) a transient "push" applied to the observed nodes (ERK, RKIPP_RP) at t=3 seconds, 3) two "wall" obstacles placed on the default trajectory in the (ERK, RKIPP_RP) transcriptional space.

Perturbation 1: applying dynamical noise

# we apply noise between t=0 and t=80 secs to ALL nodes
start = 0 
end = 80
deltaT = system_rollout.deltaT
perturbed_node_ids = list(range(len(system_rollout.grn_step.y_indexes)))

# Create the noise perturbation generator module
noise_perturbation_generator_config = Dict()
noise_perturbation_generator_config.perturbation_type = "noise"
noise_perturbation_generator_config.perturbed_intervals = [[t-deltaT/2, t+deltaT/2] for t in jnp.linspace(start, end, 2 + int((end-start)/5))[1:-1]] #add noise every 5 sec between start and end
noise_perturbation_generator_config.perturbed_node_ids = perturbed_node_ids
noise_perturbation_generator_config.std = 0.01

noise_perturbation_generator, noise_perturbation_fn = create_perturbation_module(noise_perturbation_generator_config)


# Run the system with the perturbation
key, subkey = jrandom.split(key)
noise_perturbation_params, log_data = noise_perturbation_generator(subkey, default_system_outputs)
noise_system_outputs, log_data = system_rollout(subkey, None, None,
                                                        noise_perturbation_fn, noise_perturbation_params)
Figure 6: Simulation results for default initial condition A, with noise perturbation. (left) Evolution of gene expressions (with noise perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

Perturbation 2: applying transient kicks

push_t = 3

# Create the push perturbation generator module
push_perturbation_generator_config = Dict()
push_perturbation_generator_config.perturbation_type = "push"
push_perturbation_generator_config.perturbed_intervals = [[push_t-deltaT/2, push_t+deltaT/2]]
push_perturbation_generator_config.perturbed_node_ids = observed_node_ids
push_perturbation_generator_config.magnitude = 0.1
push_perturbation_generator, push_perturbation_fn = create_perturbation_module(push_perturbation_generator_config)

# Run the system with the perturbation
key, subkey = jrandom.split(key)
push_perturbation_params, log_data = push_perturbation_generator(subkey, default_system_outputs)
push_system_outputs, log_data = system_rollout(subkey, None, None,
                                                        push_perturbation_fn, push_perturbation_params)
Figure 7: Simulation results for default initial condition A, with push perturbation. (left) Evolution of ERK and RKIPP_RP (with push perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

Perturbation 3: applying wall obstacles

# Create the wall perturbation generator module
wall_perturbation_generator_config = Dict()
wall_perturbation_generator_config.perturbation_type = "wall"
wall_perturbation_generator_config.wall_type = "force_field"
wall_perturbation_generator_config.perturbed_intervals = [[0, system_rollout_config.n_secs]]
wall_perturbation_generator_config.perturbed_node_ids = observed_node_ids
wall_perturbation_generator_config.n_walls = 2
wall_perturbation_generator_config.walls_intersection_window = [[0.1, 0.15], [0.85, 0.9]]  # in distance travelled from 0 (start point A) to 1.0 (end point B)
wall_perturbation_generator_config.walls_length_range = [[0.1, 0.1], [0.1, 0.1]]
wall_perturbation_generator_config.walls_sigma = [1e-2, 1e-4]
wall_perturbation_generator, wall_perturbation_fn = create_perturbation_module(wall_perturbation_generator_config)

# Run the system with the perturbation
key, subkey = jrandom.split(key)
wall_perturbation_params, log_data = wall_perturbation_generator(subkey, default_system_outputs)
wall_system_outputs, log_data = system_rollout(subkey, None, None,
                                                        wall_perturbation_fn, wall_perturbation_params)
Figure 8: Simulation results for default initial condition A, with wall perturbation. (left) Evolution of ERK and RKIPP_RP (with wall perturbation) through reaction time. (right) Trajectory of (ERK, RKIPP_RP) in transcriptional space.

👉 Again, the trajectory appears robust: despite the wall obstacles placed on its path, the network still manages to reach (approximately) the same endpoint B.

Simulations in batch mode

To finish Part 1 of this tutorial, let's see how AutoDiscJax allows us to perform simulations in parallel.

# Put the system in batch mode
batched_system_rollout = vmap(system_rollout, in_axes=(0, None, 0))

# Create the M=10 interventions (vector of starting positions between minval and maxval)
M = 10
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_minvals = [0.02, 0.02]
controlled_node_maxvals = [0.08, 0.08]
batched_interventions_params_1 = DictTree()
for (node_name, node_minval, node_maxval) in zip(controlled_node_names, controlled_node_minvals, controlled_node_maxvals):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    key, subkey = jrandom.split(key)
    batched_interventions_params_1.y[node_idx] = jrandom.uniform(subkey, shape=(M, 1), minval=node_minval, maxval=node_maxval)

key, *subkeys = jrandom.split(key, num=M + 1)
batched_system_outputs, log_data = batched_system_rollout(jnp.array(subkeys), intervention_fn_1, batched_interventions_params_1)

print(default_system_outputs.ys.shape, batched_system_outputs.ys.shape)

Using jax's vmap transformation, the above code automatically vectorizes the call to the system rollout module over different intervention parameters and stores the vectorized results in the batched_system_outputs variable. This is a very convenient (and fast) way to test several interventions on the biological network.

Figure 9: Simulation results for different initial condition A0, ..., A9, launched in batch mode. Obtained trajectories of (ERK, RKIPP_RP) are shown in transcriptional space.

This plot shows the M=10 resulting trajectories obtained in the transcriptional space when applying intervention 1 for 10 different starting conditions $A_0, \dots, A_9$.

👉 We can see that despite starting the simulations at 10 different positions (initial amounts of ERK and RKIPP_RP), they all converge to the same steady-state point B.
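
As a minimal reminder of the in_axes semantics used here (axis 0 of the keys and of the intervention parameters is mapped over, while the intervention function itself is shared via None), consider this toy example, unrelated to the GRN model:

def toy_rollout(key_like, shift, scale):
    return scale * jnp.arange(3.0) + shift

batched_toy_rollout = vmap(toy_rollout, in_axes=(0, None, 0))
out = batched_toy_rollout(jnp.zeros(10), 1.0, jnp.arange(10.0))
print(out.shape)  # (10, 3): keys and scales are mapped over, shift is broadcast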

Part 2: Automated experimentation approaches and challenges

batch_size = 100

random_intervention_generator_config = Dict()
random_intervention_generator_config.intervention_type = "set_uniform"
random_intervention_generator_config.controlled_node_ids = list(range(len(default_system_outputs.ys)))
random_intervention_generator_config.controlled_intervals = [[0, system_rollout.deltaT/2.0]]

random_search_discoveries = {}
for r in [1, 10, 100]:
    controlled_node_minvals = default_system_outputs.ys.min(-1)/r
    controlled_node_maxvals = default_system_outputs.ys.max(-1)*r

    intervention_params_tree = DictTree()
    for y_idx in random_intervention_generator_config.controlled_node_ids:
        intervention_params_tree.y[y_idx] = "placeholder"

    random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
    random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),),
                                             intervention_params_tree)
    random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)

    random_intervention_generator_config.low = DictTree()
    random_intervention_generator_config.high = DictTree()
    for (node_idx, node_minval, node_maxval) in zip(random_intervention_generator_config.controlled_node_ids, controlled_node_minvals, controlled_node_maxvals):
        random_intervention_generator_config.low.y[node_idx] = jnp.array([node_minval])
        random_intervention_generator_config.high.y[node_idx] = jnp.array([node_maxval])

    random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)
    batched_random_intervention_generator = vmap(random_intervention_generator)

    # Generate random interventions (batch mode)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_interventions_params, log_data = batched_random_intervention_generator(jnp.array(subkeys))

    # Rollout the system (batch mode)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs, log_data = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_interventions_params)

    # Store the discoveries
    random_search_discoveries[r] = batched_system_outputs
Figure 10: Random Search discoveries in transcriptional space, for different input parameter range r (r=1, r=10, r=100) each with N=100 exploration runs. (ERK, RKIPP_RP) endpoints (t=1000 secs) of the discovered trajectories are shown in transcriptional space.

👉 With a too-constrained range (r=1), random search misses a big part of what is feasible (in terms of reachable transcriptional space). Choosing a looser input parameter range (e.g. r=10) shows that it is possible to find novel steady states (e.g. with ERK > 5). However, this also means that the exploration space is bigger and hence harder to explore. In fact, we can see that for r=10 and r=100, random search is not very efficient, as most of the discoveries are localized on the ERK=0 axis (i.e. fall into the ERK=0 valley/attractor).

# Define loss function (L2 distance to target point)
def evaluate_worker_fn(key, intervention_params, intervention_fn, system_rollout, observed_node_ids, target_point, low, high):

    # rollout the system with parameters
    key, subkey = jrandom.split(key)
    system_outputs, log_data = system_rollout(subkey, intervention_fn, intervention_params)

    # Get trajectory final point
    reached_point = system_outputs.ys[jnp.array(observed_node_ids), -1]

    # Calc L2 distance to target point
    loss = jnp.sqrt((jnp.square((reached_point - target_point)/(high-low))).sum())

    # Append info to log data
    log_data = DictTree()
    log_data.reached_point = reached_point
    log_data.loss = loss

    return loss, log_data

previously_reached_points = random_search_discoveries[100].ys[:, jnp.array(observed_node_ids), -1]
low = jnp.nanmin(previously_reached_points, axis=0)
high = jnp.nanmax(previously_reached_points, axis=0)

target_point_1 = jnp.array([200., 4.])
evaluate_worker_fn_1 = jtu.Partial(evaluate_worker_fn, intervention_fn=intervention_fn, system_rollout=system_rollout, 
                                   observed_node_ids=observed_node_ids, target_point=target_point_1, low=low, high=high)

target_point_2 = jnp.array([200., 10.])
evaluate_worker_fn_2 = jtu.Partial(evaluate_worker_fn, intervention_fn=intervention_fn, system_rollout=system_rollout, observed_node_ids=observed_node_ids, target_point=target_point_2, low=low, high=high)


# Create SGD Optimizer
intervention_optimizer_config = Dict()
intervention_optimizer_config.n_optim_steps = 100
intervention_optimizer_config.n_workers = 1
intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda node: 0.0, random_intervention_generator.low)
intervention_optimizer_config.lr = jtu.tree_map(lambda low, high: 0.01*(high-low), 
                                                random_intervention_generator.low, random_intervention_generator.high)

optimizer = optimizers.SGDOptimizer(random_intervention_generator.out_treedef,
                                    random_intervention_generator.out_shape,
                                    random_intervention_generator.out_dtype,
                                    random_intervention_generator.low,
                                    random_intervention_generator.high,
                                    intervention_optimizer_config.n_optim_steps,
                                    intervention_optimizer_config.n_workers,
                                    intervention_optimizer_config.init_noise_std,
                                    intervention_optimizer_config.lr
                                )
# Start position 1
start_intervention_params_1 = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    start_intervention_params_1.y[node_idx] = default_system_outputs.ys[node_idx, 0][jnp.newaxis]

# Start position 2
selected_intervention_ids, distances = nearest_neighbors(target_point_1/(high-low), previously_reached_points/(high-low), k=1)
start_intervention_params_2 = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    start_intervention_params_2.y[node_idx] = random_search_discoveries[100].ys[selected_intervention_ids[0], node_idx, 0][jnp.newaxis]
# Optimization Run 1
print(f"Start position B1 {default_system_outputs.ys[jnp.array(observed_node_ids), -1]} toward target position G1 {target_point_1}")
key, subkey = jrandom.split(key)
optimized_intervention_params_1, log_data_1 = optimizer(subkey, start_intervention_params_1, evaluate_worker_fn_1)

# Optimization Run 2
print(f"Start position B2 {previously_reached_points[selected_intervention_ids][0]} toward target position G1 {target_point_1}")
key, subkey = jrandom.split(key)
optimized_intervention_params_2, log_data_2 = optimizer(subkey, start_intervention_params_2, evaluate_worker_fn_1)

# Optimization Run 3
print(f"Start position B1 {default_system_outputs.ys[jnp.array(observed_node_ids), -1]} toward target position G2 {target_point_2}")
key, subkey = jrandom.split(key)
optimized_intervention_params_3, log_data_3 = optimizer(subkey, start_intervention_params_1, evaluate_worker_fn_2)

# Optimization Run 4
print(f"Start position B2 {previously_reached_points[selected_intervention_ids][0]} toward target position G2 {target_point_2}")
key, subkey = jrandom.split(key)
optimized_intervention_params_4, log_data_4 = optimizer(subkey, start_intervention_params_2, evaluate_worker_fn_2)
Figure 11: (left) L2 loss to target (normalized). (right) Training progress trajectories in transcriptional space.

👉 Interestingly, we can see that all optimization runs make progress (the training loss decreases), but following specific curves in the transcriptional space (not heading straight toward the targets). Those curves follow the valleys of the optimization landscape, and we can clearly understand how the choice of the starting point conditions the optimization's success, as a run might get stuck in a valley or in a local minimum.

For instance, here the default initial point B1 (as given in the SBML file) fails to reach both targets (G1 and G2): its training loss reaches a plateau (blue and green curves) and the optimization gets stuck, or at least makes very slow progress (as can be seen from the blue and green trajectories in transcriptional space). Another starting point B2, which was previously found by random search (and selected because it was the closest to G1 among all discoveries), successfully manages to get very close to the target goal G1 (orange trajectory). However, it fails to reach the other target point G2 (red trajectory).

👉 This shows the importance of having a good pool of initial discoveries for reaching desired targets with optimization. Because random search is not very efficient at covering the map of possible steady states, optimization is likely to fail for many possible targets G. Can we find a more efficient way to populate the pool of discoveries, given the same experimental budget?

Part 3: Curiosity-driven search as an efficient automated discovery tool

Random Intervention Generator

# example: generate a random set of intervention parameters between low and high
key, subkey = jrandom.split(key)
intervention_params, log_data = random_intervention_generator(subkey)
print(jtu.tree_map(lambda node: node.shape, intervention_params))

# example in batch mode
key, *subkeys = jrandom.split(key, num=batch_size+ 1)
interventions_params, log_data =  batched_random_intervention_generator(jnp.array(subkeys))
print(jtu.tree_map(lambda node: node.shape, interventions_params))

Goal Embedding Encoder

goal_embedding_encoder_config = Dict()
goal_embedding_encoder_config.encoder_type = "filter"
goal_embedding_tree = "placeholder"
goal_embedding_encoder_config.out_treedef = jtu.tree_structure(goal_embedding_tree)
goal_embedding_encoder_config.out_shape = jtu.tree_map(lambda _: (len(observed_node_ids), ), goal_embedding_tree)
goal_embedding_encoder_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, goal_embedding_tree)
goal_embedding_encoder_config.filter_fn = jtu.Partial(lambda system_outputs: system_outputs.ys[..., observed_node_ids, -1])

goal_embedding_encoder = create_goal_embedding_encoder_module(goal_embedding_encoder_config)
batched_goal_embedding_encoder = vmap(goal_embedding_encoder)
# example: encode the default system outputs 
key, subkey = jrandom.split(key)
reached_goal_embedding, log_data = goal_embedding_encoder(subkey, default_system_outputs)
print(reached_goal_embedding)

# example in batch mode: encode system outputs discovered by random search
key, *subkeys = jrandom.split(key, num=batch_size+ 1)
reached_goals_embeddings, log_data = batched_goal_embedding_encoder(jnp.array(subkeys), random_search_discoveries[100])
print(reached_goals_embeddings.shape)

Goal-conditioned Achievement Loss

goal_achievement_loss_config = Dict()
goal_achievement_loss_config.loss_type = "L2" 

goal_achievement_loss = create_goal_achievement_loss_module(goal_achievement_loss_config)
batched_goal_achievement_loss = vmap(goal_achievement_loss)
# example
target_goal_embedding = target_point_1
key, subkey = jrandom.split(key)
gc_loss, log_data = goal_achievement_loss(subkey, reached_goal_embedding, target_goal_embedding)
print(gc_loss)

# example in batch mode
target_goals_embeddings = jnp.tile(target_goal_embedding[jnp.newaxis], (batch_size, 1))
key, *subkeys = jrandom.split(key, num= batch_size+ 1)
gc_losses, log_data = batched_goal_achievement_loss(jnp.array(subkeys), reached_goals_embeddings, target_goals_embeddings)
print(gc_losses.shape)

Goal Generator

goal_generator_config = DictTree()
goal_generator_config.out_treedef = goal_embedding_encoder_config.out_treedef
goal_generator_config.out_shape = goal_embedding_encoder_config.out_shape
goal_generator_config.out_dtype = goal_embedding_encoder_config.out_dtype
goal_generator_config.low = 0.0
goal_generator_config.high = None
goal_generator_config.generator_type = "hypercube"
goal_generator_config.hypercube_scaling = 1.3

goal_generator = create_goal_generator_module(goal_generator_config)
batched_goal_generator = vmap(goal_generator, in_axes=(0, None, None))
# example 
key, subkey = jrandom.split(key)
next_target_goal, log_data = goal_generator(subkey, target_goals_embeddings, reached_goals_embeddings)
print(next_target_goal)

# example in batch mode
key, *subkeys = jrandom.split(key, num= batch_size+ 1)
next_target_goals_embeddings, log_data = batched_goal_generator(jnp.array(subkeys), target_goals_embeddings, reached_goals_embeddings)
print(next_target_goals_embeddings.shape)
Figure 12: IMGEP: uniform goal sampling in the (scaled) hyperrectangle of previously reached goals.

Goal-conditioned Intervention Selector

gc_intervention_selector_config = Dict()
gc_intervention_selector_config.selector_type="nearest_neighbor"
gc_intervention_selector_config.loss_f = goal_achievement_loss.loss_f
gc_intervention_selector_config.k = 1

gc_intervention_selector = create_gc_intervention_selector_module(gc_intervention_selector_config)
batched_gc_intervention_selector = vmap(gc_intervention_selector, in_axes=(0, 0, None))
# example
key, *subkeys = jrandom.split(key, num=batch_size + 1)
source_interventions_ids, log_data = batched_gc_intervention_selector(jnp.array(subkeys), next_target_goals_embeddings, reached_goals_embeddings)
print(source_interventions_ids.shape)
Figure 13: IMGEP: goal-conditioned nearest neighbor intervention selection.

Goal-conditioned Intervention Optimizer

gc_intervention_optimizer_config = Dict()
gc_intervention_optimizer_config.out_treedef = random_intervention_generator.out_treedef
gc_intervention_optimizer_config.out_shape = random_intervention_generator.out_shape
gc_intervention_optimizer_config.out_dtype = random_intervention_generator.out_dtype
gc_intervention_optimizer_config.low = random_intervention_generator.low
gc_intervention_optimizer_config.high = random_intervention_generator.high
gc_intervention_optimizer_config.optimizer_type = "EA"
gc_intervention_optimizer_config.n_optim_steps = 1
gc_intervention_optimizer_config.n_workers = 1
gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: 0.1 * (high - low), 
                                                               gc_intervention_optimizer_config.low, gc_intervention_optimizer_config.high)

gc_intervention_optimizer = create_gc_intervention_optimizer_module(gc_intervention_optimizer_config)
null_perturbation_generator, null_perturbation_fn = create_perturbation_module(Dict(perturbation_type="null"))
null_rollout_statistics_encoder = create_rollout_statistics_encoder_module(Dict(statistics_type="null"))
partial_gc_intervention_optimizer = jtu.Partial(gc_intervention_optimizer,
                                        perturbation_generator=null_perturbation_generator, perturbation_fn=null_perturbation_fn,
                                        intervention_fn=intervention_fn, system_rollout=system_rollout,
                                        goal_embedding_encoder=goal_embedding_encoder, goal_achievement_loss=goal_achievement_loss,
                                        rollout_statistics_encoder=null_rollout_statistics_encoder
                                        )
batched_gc_intervention_optimizer = vmap(partial_gc_intervention_optimizer, in_axes=(0, 0, 0, None))
# example: optimize the initial species amounts (starting from the default ones) to reach the ERK-RKIPP_RP target steady state G1 (200, 4)
key, subkey = jrandom.split(key)
optimized_intervention_params, log_data = partial_gc_intervention_optimizer(subkey, start_intervention_params_1, target_point_1, reached_goals_embeddings)
print(jtu.tree_map(lambda node: node.shape, optimized_intervention_params))

# example in batch mode: optimize the selected interventions (closest to target) toward their respective targets
previous_interventions_params = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    previous_interventions_params.y[node_idx] = random_search_discoveries[100].ys[:, node_idx, 0]

start_interventions_params = jtu.tree_map(lambda x: x[source_interventions_ids], previous_interventions_params)

key, *subkeys = jrandom.split(key, num=batch_size + 1)
optimized_interventions_params, log_data = batched_gc_intervention_optimizer(jnp.array(subkeys), start_interventions_params, next_target_goals_embeddings, reached_goals_embeddings)
print(jtu.tree_map(lambda node: node.shape, optimized_interventions_params))
Figure 14: IMGEP: goal-conditioned intervention optimization (here local diffusion).

👉 Progress is made toward G3 and G7. No progress toward G9 (the selected intervention was already quite close to it), but even if the new point Z9 is further from G9, it falls in an uncovered area and will be useful for future goals.

Figure 15: IMGEP: discoveries after one iteration.

👉 With this simple example we can already grasp why the IMGEP will be much more efficient at discovering diverse possible final states.

Run experiment pipeline

   

# Run IMGEP
jax_platform_name = "cpu"
seed = 0
n_random_batches = 2 
n_imgep_batches = 8
batch_size = 20
imgep_experiment_data_save_folder = "data/imgep_data"
if not os.path.exists(os.path.join(imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         imgep_experiment_data_save_folder,
                         random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)

# Run Random Search
rs_experiment_data_save_folder = "data/rs_data"
if not os.path.exists(os.path.join(rs_experiment_data_save_folder, "history.pickle")):
    run_rs_experiment(jax_platform_name, seed, n_random_batches+n_imgep_batches, batch_size, 
                      rs_experiment_data_save_folder,
                      random_intervention_generator, intervention_fn,
                      null_perturbation_generator, null_perturbation_fn,
                      system_rollout, null_rollout_statistics_encoder,
                      out_sanity_check=False, save_modules=False, save_logs=False)
imgep_experiment_history = DictTree.load(os.path.join(imgep_experiment_data_save_folder, "history.pickle"))
imgep_reached_goals_embeddings = imgep_experiment_history.reached_goal_embedding_library
print(imgep_reached_goals_embeddings.shape)
rs_experiment_history = DictTree.load(os.path.join(rs_experiment_data_save_folder, "history.pickle"))
rs_reached_goals_embeddings = rs_experiment_history.system_output_library.ys[:, jnp.array(observed_node_ids), -1]
print(rs_reached_goals_embeddings.shape)

👉 Sample efficiency

Intrinsically motivated goal exploration algorithms are designed to autonomously discover the widest range of diverse effects that can be produced in an initially unknown system (here, our biological network). Thus, a first way to evaluate an exploration algorithm is to measure how well and how fast it covers the state space, in particular in comparison to random search.

epsilon = 0.033

analytic_bc_space_low = jnp.minimum(jnp.nanmin(imgep_reached_goals_embeddings, 0), jnp.nanmin(rs_reached_goals_embeddings, 0))
analytic_bc_space_high = jnp.maximum(jnp.nanmax(imgep_reached_goals_embeddings, 0), jnp.nanmax(rs_reached_goals_embeddings, 0))

def calc_analytic_bc_coverage(reached_endpoints, epsilon):        
    for step_idx, reached_endpoint in enumerate(reached_endpoints):
        if step_idx == 0:
            union_polygon = Point(reached_endpoints[0]).buffer(epsilon)
            covered_areas = [union_polygon.area]
        else:
            union_polygon = unary_union([union_polygon, Point(reached_endpoint).buffer(epsilon)])
            covered_areas.append(union_polygon.area)

    return union_polygon, covered_areas

imgep_reached_endpoints = (imgep_reached_goals_embeddings-analytic_bc_space_low) / (analytic_bc_space_high-analytic_bc_space_low)
imgep_union_polygon, imgep_covered_areas = calc_analytic_bc_coverage(imgep_reached_endpoints, epsilon=epsilon)
rs_reached_endpoints = (rs_reached_goals_embeddings-analytic_bc_space_low) / (analytic_bc_space_high-analytic_bc_space_low)
rs_union_polygon, rs_covered_areas = calc_analytic_bc_coverage(rs_reached_endpoints, epsilon=epsilon)
Figure 16: Diversity of behaviors discovered by the different algorithm variants. (left) All discovered behaviors (stable endpoints). (middle) Discovered reachable space (union of epsilon-radius balls centered around the discovered endpoints) for random search (pink) and IMGEP (blue). Results are shown for epsilon=0.033. (right) Diversity of behaviors discovered throughout exploration, where the area of the discovered reachable space is used as the diversity measure.

👉 Finding Salient Stimuli

# calc clusters in behavior space
clusterer = hdbscan.HDBSCAN(min_cluster_size=10, cluster_selection_epsilon=0.1)
imgep_clusters_labels = clusterer.fit_predict(imgep_reached_endpoints)
rs_clusters_labels = clusterer.fit_predict(rs_reached_endpoints)

# project sampled params in 2D space with TSNE
imgep_sampled_params = jnp.array(list(imgep_experiment_history.intervention_params_library.y.values())).squeeze().transpose()
rs_sampled_params = jnp.array(list(rs_experiment_history.intervention_params_library.y.values())).squeeze().transpose()
all_sampled_params = jnp.concatenate([imgep_sampled_params, rs_sampled_params])

tsne = TSNE(n_components=2)
all_sampled_params = tsne.fit_transform(all_sampled_params)
imgep_sampled_params, rs_sampled_params = all_sampled_params[:200], all_sampled_params[200:]
#@title [Figure 17]
fig_idx = 17

if nb_mode == "run":

    fig = make_subplots(rows=2, cols=2, horizontal_spacing=0.01, vertical_spacing=0.1, subplot_titles=["<b>(a) Curiosity Search</b><br> ","","<b>(b) Random Search</b><br> ",""])
    fig.update_layout(**default_layout, margin_t=40)
    fig.update_annotations(font_size=12, x=0.1)

    # Add points / shape contour per IMGEP cluster
    for row_idx, (clusters_labels, reached_endpoints, sampled_params) in enumerate(zip([imgep_clusters_labels, rs_clusters_labels],
                                                                                      [imgep_reached_endpoints, rs_reached_endpoints],
                                                                                      [imgep_sampled_params, rs_sampled_params],
                                                                                     )):
        for label_idx in set(clusters_labels):
            label_idx = label_idx.item()

            cluster_point_ids = jnp.where(clusters_labels==label_idx)[0]
            z_points = reached_endpoints[cluster_point_ids]
            i_points = sampled_params[cluster_point_ids]

            if label_idx < 0:
                show_shape = False
                marker_size = 4
                color_idx = 7
                name = f"N/A"
            else:
                show_shape = True
                marker_size = 4
                color_idx = [4,3,9,6][label_idx] if row_idx==0 else [2,8][label_idx]
                name = f"cluster {label_idx+1}"

            # shape contour
            if show_shape:
                for col_idx, (points, eps) in enumerate(zip([i_points, z_points], [1.5, 0.05])):
                    poly = unary_union([Point(point).buffer(eps) for point in points])
                    poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
                    x, y = [], []
                    if poly.geom_type == 'MultiPolygon':
                        for geom in poly.geoms:
                            geom_x, geom_y = geom.exterior.coords.xy
                            x.append(np.array(geom_x))
                            y.append(np.array(geom_y))
                    elif poly.geom_type == 'Polygon':
                        geom_x, geom_y = poly.exterior.coords.xy
                        x.append(np.array(geom_x))
                        y.append(np.array(geom_y))
                    for i, (x, y) in enumerate(zip(x,y)):
                        fig.add_trace(go.Scatter(x=x, y=y, fill="toself", fillcolor=default_colors_shade[color_idx],
                                                 name=name, legendgroup=row_idx, showlegend=(i==0)&(col_idx==0),
                                                 line=dict(color=default_colors[color_idx]), hoverinfo="skip"), row=row_idx+1, col=col_idx+1)

            # points
            fig.add_trace(go.Scatter(x=z_points[:,0], y=z_points[:,1],
                                     name=name, legendgroup=row_idx, showlegend=(label_idx==-1),
                                     mode="markers", marker_color=default_colors[color_idx], marker_size=marker_size), 
                          row=row_idx+1, col=2)
            fig.add_trace(go.Scatter(x=i_points[:,0], y=i_points[:,1], 
                                     name=name, legendgroup=row_idx, showlegend=False,
                                     mode="markers", marker_color=default_colors[color_idx], marker_size=marker_size), 
                          row=row_idx+1, col=1)


    # Add background shape
    for row_idx in [1,2]:
        fig.add_vrect(x0=-0.1, x1=1.05, 
                      fillcolor="#BAE1FF", opacity=.8,
                      line=dict(color="#6C8EBF", width=2), 
                      annotation_text="Behavior Space <i>Z</i>", annotation_position="top", annotation_font=dict(color='black', size=14),
                      layer="below", row=row_idx, col=2)
        fig.add_vrect(x0=-40, x1=30, 
                      fillcolor="#F6E785", opacity=.8,
                      line=dict(color="#C79714", width=2), 
                      annotation_text="Intervention Space <i>I</i>", annotation_position="top", annotation_font=dict(color='black', size=14),
                      layer="below", row=row_idx, col=1)


    # Update Layout 
    fig.update_layout(legend_tracegroupgap=320)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)


    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto1_fig_{fig_idx}.json")


elif nb_mode == "load":
    fig = plotly.io.read_json(f"figures/tuto1_fig_{fig_idx}.json")


# Display Fig
width, height = 900, 800
t = f"Mapping between Intervention Space and Behavior space. " 

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    display(Markdown(img_title))
Figure 17: Mapping between Intervention Space and Behavior space.

👉 Adaptive Design Choices

Finally, we show that the IMGEP facilitates the engineer's task (e.g. when we are not sure about the appropriate range for the parameter and/or goal space).

Adaptive Goal Space

Figure 18: IMGEP: adaptive goal space extent.

Adaptive Parameter Space Variant

r= 1

# Random generator initially constrained to the tight range (r=1) of what we know is feasible from the default rollout
# but this time only for the first 10% of runs; then we let the algorithm adapt its parameter space extent
n_random_batches = 1
n_imgep_batches = 9
constrained_random_intervention_generator_config = deepcopy(random_intervention_generator_config)
for node_idx in random_intervention_generator_config.controlled_node_ids:
    constrained_random_intervention_generator_config.low.y[node_idx] = default_system_outputs.ys[node_idx].min() / r
    constrained_random_intervention_generator_config.high.y[node_idx] = default_system_outputs.ys[node_idx].max() * r   
constrained_random_intervention_generator, intervention_fn = create_intervention_module(constrained_random_intervention_generator_config)

# Goal-conditioned optimizer: remove (low, high) constraints 
## and set local mutation amplitude based on what we know is feasible (r=1) from the default rollout
adaptive_gc_intervention_optimizer_config = deepcopy(gc_intervention_optimizer_config)
adaptive_gc_intervention_optimizer_config.low = jtu.tree_map(lambda node: jnp.zeros_like(node), gc_intervention_optimizer_config.low)
adaptive_gc_intervention_optimizer_config.high = None
adaptive_gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: r*(high-low), 
                                                                        constrained_random_intervention_generator.low,
                                                                        constrained_random_intervention_generator.high)
adaptive_gc_intervention_optimizer = create_gc_intervention_optimizer_module(adaptive_gc_intervention_optimizer_config)

# Run Adaptive IMGEP variant
adaptive_imgep_experiment_data_save_folder = "data/adaptive_imgep_data"
if not os.path.exists(os.path.join(adaptive_imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         adaptive_imgep_experiment_data_save_folder,
                         constrained_random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, adaptive_gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)


# Run Adaptive RandomMut variant (IMGEP with null goal generation and random intervention selection)
null_goal_generator_config = deepcopy(goal_generator_config)
null_goal_generator_config.hypercube_scaling = 0.0
null_goal_generator = create_goal_generator_module(null_goal_generator_config)

random_intervention_selector_config = deepcopy(gc_intervention_selector_config)
random_intervention_selector_config.selector_type = "random"
random_intervention_selector = create_gc_intervention_selector_module(random_intervention_selector_config)

adaptive_rmut_experiment_data_save_folder = "data/adaptive_rmut_data"
if not os.path.exists(os.path.join(adaptive_rmut_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         adaptive_rmut_experiment_data_save_folder,
                         constrained_random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         null_goal_generator, random_intervention_selector, adaptive_gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)
adaptive_imgep_experiment_history = DictTree.load(os.path.join(adaptive_imgep_experiment_data_save_folder, "history.pickle"))
adaptive_imgep_reached_goals_embeddings = adaptive_imgep_experiment_history.reached_goal_embedding_library
adaptive_rmut_experiment_history = DictTree.load(os.path.join(adaptive_rmut_experiment_data_save_folder, "history.pickle"))
adaptive_rmut_reached_goals_embeddings = adaptive_rmut_experiment_history.system_output_library.ys[:, jnp.array(observed_node_ids), -1]

adaptive_analytic_bc_space_low = jnp.minimum(jnp.nanmin(adaptive_imgep_reached_goals_embeddings, 0), jnp.nanmin(adaptive_rmut_reached_goals_embeddings, 0))
adaptive_analytic_bc_space_high = jnp.maximum(jnp.nanmax(adaptive_imgep_reached_goals_embeddings, 0), jnp.nanmax(adaptive_rmut_reached_goals_embeddings, 0))

adaptive_imgep_reached_endpoints = (adaptive_imgep_reached_goals_embeddings-adaptive_analytic_bc_space_low) / (adaptive_analytic_bc_space_high-adaptive_analytic_bc_space_low)
adaptive_imgep_union_polygon, adaptive_imgep_covered_areas = calc_analytic_bc_coverage(adaptive_imgep_reached_endpoints, epsilon=epsilon)
adaptive_rmut_reached_endpoints = (adaptive_rmut_reached_goals_embeddings-adaptive_analytic_bc_space_low) / (adaptive_analytic_bc_space_high-adaptive_analytic_bc_space_low)
adaptive_rmut_union_polygon, adaptive_rmut_covered_areas = calc_analytic_bc_coverage(adaptive_rmut_reached_endpoints, epsilon=epsilon)
Figure 19: Diversity of behaviors discovered by the adaptive parameter space algorithm variants. (left) All discovered behaviors (stable endpoints). (middle) Discovered reachable space (union of epsilon-radius balls centered around the discovered endpoints) for the adaptive RandomMut variant (pink) and the adaptive IMGEP (blue). Results are shown for epsilon=0.033. (right) Diversity of behaviors discovered throughout exploration, where the area of the discovered reachable space is used as the diversity measure.

Part 4: Empirical tests for analyzing navigation competencies

# We only test the last 40 trajectories (to go faster)
eval_system_outputs_library = jtu.tree_map(lambda node: node[-2*batch_size:], imgep_experiment_history.system_output_library)
eval_intervention_params_library = jtu.tree_map(lambda node: node[-2*batch_size:], imgep_experiment_history.intervention_params_library)

Trajectory characteristics

def calc_settling_time(dist_vals, settling_time_threshold):
    # assume normalized dist_vals starting from 1 and finishing at 0
    settling_time = jnp.where(~(dist_vals < settling_time_threshold), size=len(dist_vals), fill_value=-1)[0].max()
    return settling_time


def calc_travelling_time(trajectories):
    distance_travelled = jnp.cumsum(jnp.sqrt(jnp.sum(jnp.diff(trajectories, axis=-1)** 2, axis=-2)), axis=-1)
    distance_travelled = distance_travelled / distance_travelled.max(-1) # normalize between 0 and 1
    T10 = jnp.where(distance_travelled >= 0.1, size=distance_travelled.shape[-1], fill_value=-1)[0][0]
    T90 = jnp.where(distance_travelled >= 0.9, size=distance_travelled.shape[-1], fill_value=-1)[0][0]
    return T10, T90


def calc_trajectories_statistics(trajectories, deltaT, settling_time_threshold):
    trajectories = trajectories[..., 1:]  # remove the first step (to avoid taking into account the big jumps that can happen at the first step)

    # settling time:  first time T such that the distance between y(t) and yfinal ≤ 0.02 × |yfinal – yinit| for t ≥ T
    # normalize such that origin is final point and unit=(end-origin)
    extent = (trajectories.max(-1) - trajectories.min(-1))
    extent = extent.at[extent == 0.].set(1.)
    normalized_trajectories = trajectories / extent[..., jnp.newaxis]
    distance_to_target = jnp.linalg.norm(normalized_trajectories - normalized_trajectories[..., -1][..., jnp.newaxis], axis=1)
    distance_to_target = distance_to_target / distance_to_target[:, 0][:, jnp.newaxis]
    settling_times = vmap(calc_settling_time, in_axes=(0, None))(distance_to_target, settling_time_threshold)

    # travelling time: time it takes for the response to travel from 10% to 90% of the way from yinit to yfinal
    T10s, T90s = vmap(calc_travelling_time)(normalized_trajectories)

    # detours (duration and area)
    detours_duration = []
    detours_area = []
    detours_timesteps = []

    for sample_idx in range(len(distance_to_target)):
        detour_timesteps = []
        detour_duration = 0.
        detour_area = 0.

        if settling_times[sample_idx] > 0:
            cur_distance_to_target = distance_to_target[sample_idx, :settling_times[sample_idx]]
            is_distance_increasing = jnp.concatenate(
                [jnp.array([False]), jnp.diff(cur_distance_to_target) > 0])
            is_distance_decreasing = jnp.concatenate(
                [jnp.array([True]), jnp.diff(cur_distance_to_target) < 0])
            start_detour_timesteps = jnp.where(is_distance_decreasing[:-1] & is_distance_increasing[1:])[0]
            if len(start_detour_timesteps) > 0:
                start_detour_dist_vals = cur_distance_to_target[start_detour_timesteps]
                end_detour_timesteps = []

                for start_detour_timestep, start_detour_dist_val in zip(start_detour_timesteps, start_detour_dist_vals):
                    possible_detour_timesteps = jnp.where((cur_distance_to_target[:-1] >= start_detour_dist_val) &
                                                          (cur_distance_to_target[1:] <= start_detour_dist_val))[0] + 1
                    # take the first time step (after start_detour_timestep) where distance curve is crossing back y=start_detour_dist_val
                    # if no crossing back before settling time, we consider settling time as the end of the detour
                    possible_end_detour_timesteps = possible_detour_timesteps[possible_detour_timesteps > start_detour_timestep]
                    if len(possible_end_detour_timesteps) > 0:
                        end_detour_timestep = possible_end_detour_timesteps[0]
                    else:
                        end_detour_timestep = settling_times[sample_idx]-1
                    end_detour_timesteps.append(end_detour_timestep)

                # calc union of intervals (in case some overlaps due to noise)
                detour_timesteps = jnp.where(jnp.array([(jnp.arange(len(cur_distance_to_target)) >= start) &
                                                        (jnp.arange(len(cur_distance_to_target)) <= end)
                                                        for (start, end) in
                                                        zip(start_detour_timesteps, end_detour_timesteps)]).any(0))[0]
                detour_duration = len(detour_timesteps)

                rel_start_detours_timesteps = jnp.concatenate([jnp.array([0]), jnp.where((detour_timesteps[1:] - detour_timesteps[:-1]) > 1)[0]+1])
                rel_end_detours_timesteps = jnp.roll(rel_start_detours_timesteps-1, -1)

                detour_polygon = Polygon()
                valid_detour_timesteps = jnp.empty((0,))
                for start, end in zip(rel_start_detours_timesteps, rel_end_detours_timesteps):
                    detour_points = normalized_trajectories[sample_idx][:, detour_timesteps[start:end]].transpose()
                    if len(detour_points) >= 3:
                        cur_detour_polygon = Polygon([*detour_points])
                        if cur_detour_polygon.is_valid:
                            detour_polygon = unary_union([detour_polygon, cur_detour_polygon])
                            valid_detour_timesteps = jnp.concatenate([valid_detour_timesteps, detour_timesteps[start:end]])

                detour_area = detour_polygon.area

        detours_timesteps.append(detour_timesteps)
        detours_duration.append(detour_duration)
        detours_area.append(detour_area)   


    trajectories_statistics = DictTree()
    trajectories_statistics.distance_to_target = distance_to_target
    trajectories_statistics.settling_times = settling_times
    trajectories_statistics.T10s = T10s
    trajectories_statistics.T90s = T90s
    trajectories_statistics.detours_timesteps = detours_timesteps
    trajectories_statistics.detours_duration = jnp.array(detours_duration) 
    trajectories_statistics.detours_area = jnp.array(detours_area)

    return trajectories_statistics
deltaT = system_rollout_config.deltaT
settling_time_threshold=0.02
trajectories = eval_system_outputs_library.ys[:, jnp.array(observed_node_ids), :]
trajectories_statistics = calc_trajectories_statistics(trajectories, deltaT, settling_time_threshold)
Figure 20: Trajectory characteristics.
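To build intuition for these statistics, here is a small sanity check (purely illustrative and not part of the analysis above; the toy decay rate and horizon are arbitrary) that applies calc_trajectories_statistics to a synthetic exponentially-decaying two-node trajectory:

# Illustrative sanity check with an arbitrary toy trajectory (not from the experiment):
# a smooth exponential decay should settle without detours and give T10 < T90
ts = jnp.arange(0, 100)
toy_trajectories = jnp.stack([jnp.exp(-0.05 * ts), 2.0 * jnp.exp(-0.05 * ts)])[jnp.newaxis]  # shape (1 sample, 2 nodes, 100 steps)
toy_statistics = calc_trajectories_statistics(toy_trajectories, deltaT, settling_time_threshold)
print(toy_statistics.settling_times, toy_statistics.T10s, toy_statistics.T90s)  # detours_duration should stay at 0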

Robustness tests

# Update the system_rollout to run for a bit longer
eval_system_rollout_config = deepcopy(system_rollout_config)
eval_system_rollout_config.n_secs = int(system_rollout_config.n_secs*1.2) 
eval_system_rollout_config.n_system_steps = int(eval_system_rollout_config.n_secs/eval_system_rollout_config.deltaT)

eval_system_rollout = create_system_rollout_module(eval_system_rollout_config)
batched_eval_system_rollout = vmap(eval_system_rollout, in_axes=(0, None, 0, None, 0))

# perturbation hyperparams
perturbation_min_duration = 50
perturbation_max_duration = 500
T10 = jnp.median((trajectories_statistics.T10s+1))*deltaT
T90 = jnp.median((trajectories_statistics.T90s+1))*deltaT
start = T10
end = min(max(T90, start+perturbation_min_duration), start+perturbation_max_duration)
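To make the clipping of the perturbation window explicit, here is a tiny worked example with hypothetical medians (these numbers are illustrative, not taken from the actual run):

# Hypothetical medians, in seconds (illustration only)
T10_example, T90_example = 20.0, 800.0
start_example = T10_example
end_example = min(max(T90_example, start_example + perturbation_min_duration), start_example + perturbation_max_duration)
print(start_example, end_example)  # 20.0 520.0 -> the window is clipped to the 500-second maximum duration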

Evaluating the robustness of the discovered behavioral abilities

test_tasks = {
            "noise_std": [0.001, 0.005, 0.01],
            "noise_period": [10, 5, 1],
            "push_magnitude": [0.05, 0.1, 0.15],
            "push_number": [1, 2, 3],
            "wall_length": [0.05, 0.1, 0.15],
            "wall_number": [1, 2, 3]
        }
def get_perturbation_generator_config(var_name, var_val):
    perturbation_generator_config = Dict()

    if var_name.split("_")[0] == "noise":
        perturbation_generator_config = Dict()
        perturbation_generator_config.perturbation_type = "noise"
        perturbation_generator_config.perturbed_node_ids = random_intervention_generator_config.controlled_node_ids


        if var_name.split("_")[1] == "std":
            n_noises = int((end-start)//test_tasks["noise_period"][1])
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + n_noises)[1:-1]] #5
            perturbation_generator_config.std = var_val

        elif var_name.split("_")[1] == "period":
            n_noises = int((end - start) // float(var_val))
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + n_noises)[1:-1]]
            perturbation_generator_config.std = test_tasks["noise_std"][1]

    elif var_name.split("_")[0] == "push":
        perturbation_generator_config = Dict()
        perturbation_generator_config.perturbation_type = "push"
        perturbation_generator_config.perturbed_node_ids = observed_node_ids

        if var_name.split("_")[1] == "magnitude":
            perturbation_generator_config.perturbed_intervals = [[int((T90+T10)/2.)-deltaT/2, int((T90+T10)/2.)+deltaT/2]]
            perturbation_generator_config.magnitude = var_val

        elif var_name.split("_")[1] == "number":
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + var_val)[1:-1]]
            perturbation_generator_config.magnitude = test_tasks["push_magnitude"][1]

    elif var_name.split("_")[0] == "wall":
        perturbation_generator_config = Dict()
        perturbation_generator_config.perturbation_type = "wall"
        perturbation_generator_config.wall_type = "force_field"
        perturbation_generator_config.perturbed_intervals = [[0, eval_system_rollout_config.n_secs]]
        perturbation_generator_config.perturbed_node_ids = observed_node_ids
        perturbation_generator_config.walls_sigma = [1e-2, 1e-4]

        if var_name.split("_")[1] == "length":
            perturbation_generator_config.n_walls = 1
            perturbation_generator_config.walls_intersection_window = [[0.1, 0.9]]
            perturbation_generator_config.walls_length_range = [[var_val, var_val]]

        elif var_name.split("_")[1] == "number":
            perturbation_generator_config.n_walls = var_val
            walls_windows_pos = jnp.linspace(0, 1, 2 + var_val)
            walls_spacing = (walls_windows_pos[1] - walls_windows_pos[0]) * 3 / 4
            perturbation_generator_config.walls_intersection_window = [[t - walls_spacing, t + walls_spacing] for t in walls_windows_pos[1:-1]]
            perturbation_generator_config.walls_length_range = [[test_tasks["wall_length"][1], test_tasks["wall_length"][1]]] * var_val

    return perturbation_generator_config
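Before launching the full test campaign below, it can be useful to inspect a single generated configuration. A minimal illustration (not part of the original loop), relying on the start, end and deltaT values computed above:

# Illustration only: build and inspect one perturbation config outside the main loop
example_config = get_perturbation_generator_config("push_number", 2)
print(example_config.perturbation_type)    # "push"
print(example_config.perturbed_intervals)  # two short intervals evenly spaced between start and end
print(example_config.magnitude)            # 0.1, i.e. test_tasks["push_magnitude"][1]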
n_perturbations = 3
root_test_save_folder = "data/robustness_tests"

for test_task_var_name, test_task_var_range in test_tasks.items():
    for test_task_var_val in test_task_var_range:
        test_save_folder=f"{root_test_save_folder}/{test_task_var_name}_{test_task_var_val}"
        print(test_save_folder)

        if not os.path.exists(test_save_folder):
            perturbation_generator_config = get_perturbation_generator_config(test_task_var_name, test_task_var_val)
            perturbation_generator, perturbation_fn = create_perturbation_module(perturbation_generator_config)

            run_robustness_tests(jax_platform_name, seed, n_perturbations, test_save_folder,
                                 eval_system_outputs_library, 
                                 eval_intervention_params_library, intervention_fn,
                                 perturbation_generator, perturbation_fn,
                                 eval_system_rollout, null_rollout_statistics_encoder,
                                 out_sanity_check=False, save_modules=False, save_logs=False)
Figure 21: Robustness test results.
Figure 22: Example of a robust pathway A->B (gray, prior to perturbation), stressed with various kinds of perturbations (dropdown menu). Perturbations with smaller (top), intermediate (middle) and stronger (bottom) intensities/frequencies are shown, for 3 random perturbation generations.
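Part 5 below reuses a per-pathway sensitivity table (the data dictionary) computed from these saved test histories. As a rough, purely illustrative sketch of how such a table could be assembled (the exact metric behind Figure 21 may differ), one option is to measure the normalized distance between perturbed and unperturbed endpoints, averaged over the perturbation repeats:

# Hypothetical sketch of a per-pathway sensitivity table (the actual metric may differ):
# sensitivity = distance between perturbed and unperturbed endpoints, normalized by the behavior-space extent
data = {}
unperturbed_endpoints = eval_system_outputs_library.ys[:, jnp.array(observed_node_ids), -1]
bc_extent = analytic_bc_space_high - analytic_bc_space_low
for task_name, task_var_range in test_tasks.items():
    for task_var in task_var_range:
        test_history = DictTree.load(os.path.join(f"{root_test_save_folder}/{task_name}_{task_var}", "history.pickle"))
        # perturbed ys have shape (n_pathways, n_perturbations, n_nodes, n_steps)
        perturbed_endpoints = test_history.system_output_library.ys[:, :, jnp.array(observed_node_ids), -1]
        distances = jnp.linalg.norm((perturbed_endpoints - unperturbed_endpoints[:, jnp.newaxis]) / bc_extent, axis=-1)
        data[f"{task_name}_{task_var}"] = distances.mean(-1)  # average over the n_perturbations repeats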

Part 5: Perspectives for reuses of the behavioral catalog

Energy Landscape ⛰

In this part, we illustrate how the discoveries can be reused to construct energy landscapes of the GRN, following the approach of the "Trajectory-based energy landscapes of gene regulatory networks" paper. The idea is to bin the states of the two observed nodes (ERK and RKIPP_RP) visited by the discovered trajectories into a 2D histogram H, interpret the normalized counts as a visitation probability P = H / H.sum(), and define a pseudo-energy U = -log(P), which is then smoothed by cubic interpolation in the code below: frequently visited regions appear as low-energy valleys, rarely visited regions as high-energy barriers.

    
imgep_trajectories = imgep_experiment_history.system_output_library.ys[:, observed_node_ids, :]
is_valid_bool = ~(jnp.isnan(imgep_trajectories).any(-1).any(-1)) & (imgep_trajectories>=-1e-6).all(-1).all(-1)
imgep_points = jnp.concatenate(imgep_trajectories[is_valid_bool], axis=-1)

rs_trajectories = rs_experiment_history.system_output_library.ys[:, observed_node_ids, :]
is_valid_bool = ~(jnp.isnan(rs_trajectories).any(-1).any(-1)) & (rs_trajectories>=-1e-6).all(-1).all(-1)
rs_points = jnp.concatenate(rs_trajectories[is_valid_bool], axis=-1)


imgep_perturbed_points = []
for task_name in test_tasks.keys():
    for task_var in test_tasks[task_name]:
        test_experiment_history = DictTree.load(os.path.join(f"{root_test_save_folder}/{task_name}_{task_var}", "history.pickle"))
        perturbed_trajectories = test_experiment_history.system_output_library.ys[:, :, observed_node_ids, :].reshape(len(eval_system_outputs_library.ys)*n_perturbations, 2, -1)
        is_valid_bool = ~(jnp.isnan(perturbed_trajectories).any(-1).any(-1)) & (perturbed_trajectories>=-1e-6).all(-1).all(-1)
        imgep_perturbed_points.append(jnp.concatenate(perturbed_trajectories[is_valid_bool], axis=-1))
imgep_perturbed_points = jnp.concatenate(imgep_perturbed_points, axis=-1)

all_points = jnp.concatenate([imgep_points, rs_points, imgep_perturbed_points], axis=-1)
ymin, ymax = all_points.min(-1), all_points.max(-1)
del all_points 


results = {}
for k, points in zip(["(a) Random Search", "(b) IMGEP", " (c) IMGEP+perturbations"], [rs_points, imgep_points, imgep_perturbed_points]):

    H, xedges, yedges = jnp.histogram2d(
        x=points[0, :],
        y=points[1, :],
        bins=10,
        range=jnp.stack([ymin, ymax]).transpose()
    )
    H = H.transpose()

    # Compute probability distribution P
    H = H.at[jnp.where(H == 0)].set(1)
    U = -jnp.log(H / H.sum())

    # Calculate energy Landscape
    bin_sizex = xedges[1] - xedges[0]
    bin_sizey = yedges[1] - yedges[0]
    x = xedges[:-1] + bin_sizex / 2
    y = yedges[:-1] + bin_sizey / 2
    z = U.flatten()
    interp = RegularGridInterpolator((x,y), U, method="cubic", bounds_error=False)
    xi = jnp.linspace(ymin[0], ymax[0], 100)
    yi = jnp.linspace(ymin[1], ymax[1], 100)
    xi, yi = jnp.meshgrid(xi, yi)
    zi = interp((xi, yi)).transpose()

    # save results
    results[k] = (xi, yi, zi)
Figure 23: Trajectory-based energy landscapes constructed from the different sets of discoveries: random search (left), IMGEP search (middle), robustness tests (right).

Therapeutic Perspectives 💊

# Convert the system rollout to batch mode
batched_system_rollout = vmap(system_rollout, in_axes=(0, None, 0, None, 0), out_axes=(0,None))

# Get most robust trajectories
mean_sensitivities = jnp.nanmean(jnp.stack(list(data.values())), axis=0)
Q25 = jnp.percentile(mean_sensitivities, 25)
rob_sample_ids = jnp.where(mean_sensitivities < Q25)[0]

# "healthy" and "disease" regions polygons
regions_poly = {}
for label_idx, region_name in zip([1,2], ("healthy", "disease")):
    cluster_point_ids = jnp.where(imgep_clusters_labels==label_idx)[0]
    z_points = imgep_reached_endpoints[cluster_point_ids]
    eps=0.05
    poly = unary_union([Point(point).buffer((eps)) for point in z_points])
    poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
    poly = affinity.scale(poly, xfact=(analytic_bc_space_high-analytic_bc_space_low)[0], yfact=(analytic_bc_space_high-analytic_bc_space_low)[1], origin=(0,0,0))
    poly = affinity.translate(poly, xoff=analytic_bc_space_low[0], yoff=analytic_bc_space_low[1])
    regions_poly[region_name] = poly

# Select pathways in "disease" region
disease_pathways_ids = []
for sample_idx in rob_sample_ids:
    if regions_poly["disease"].contains(Point(eval_system_outputs_library.ys[sample_idx, observed_node_ids, -1])):
        disease_pathways_ids.append(sample_idx)
disease_pathways_ids = jnp.array(disease_pathways_ids)

# Target endpoint: centroid of the "healthy" region
target_endpoint = jnp.array(regions_poly["healthy"].centroid.coords)
print(f"Target: {target_endpoint}")


# Create intervention function 
class CustomInterventionFn(grn.PiecewiseIntervention):
    # Piecewise intervention: the first interval (placed around t=0 via TimeToInterval below)
    # resets the full initial state to intervention_params.y0, while every subsequent
    # controlled interval clamps the selected node(s) to the corresponding target value.
    def apply(self, key, y, y_, w, w_, c, c_, interval_idx, intervention_params):
        y = lax.cond(interval_idx.sum() == 0, self.apply_init, self.apply_clamping, y, interval_idx, intervention_params)
        return y, w, c

    def apply_init(self, y, interval_idx, intervention_params):
        # first interval: set the system state to the provided initial state
        return intervention_params.y0

    def apply_clamping(self, y, interval_idx, intervention_params):
        # later intervals: clamp each controlled node to its value for the current interval
        for y_idx, clamp_vals in intervention_params.clamping.items():
            y = y.at[y_idx].set(clamp_vals[interval_idx-1])
        return y

controlled_node = "MEKPP"
controlled_intervals = [[i, i+10] for i in range(10,100,10)]
custom_intervention_fn = CustomInterventionFn(grn.TimeToInterval(intervals=[[-deltaT/2., deltaT/2.]]+controlled_intervals))
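For reference, the parameters that this custom intervention expects have roughly the following structure (hypothetical placeholder values; the actual parameters are sampled in the random-search loop further below):

# Hypothetical example of the (unbatched) parameters expected by CustomInterventionFn:
# a full initial state y0 plus, for each clamped node index, one value per controlled interval
example_intervention_params = DictTree()
example_intervention_params.y0 = jnp.zeros(len(system_rollout.grn_step.y_indexes))  # initial concentrations for all species
example_intervention_params.clamping[system_rollout.grn_step.y_indexes[controlled_node]] = jnp.ones(len(controlled_intervals))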

# Perturbation functions (similar to those in Part 1)
batched_noise_perturbation_generator = vmap(noise_perturbation_generator)
push_perturbation_generator_config.perturbed_intervals = [[int((T90+T10)/2.)-deltaT/2, int((T90+T10)/2.)+deltaT/2]]
push_perturbation_generator, _ = create_perturbation_module(push_perturbation_generator_config)
batched_push_perturbation_generator = vmap(push_perturbation_generator)
batched_wall_perturbation_generator = vmap(wall_perturbation_generator)

print(f"We search a stepwise intervention applied on {controlled_node} and on time intervals {controlled_intervals}, that will allow to get unstuck from 'disease' to 'healthy' region under a various perturbations")

# Define train loss: distance to the center of the "healthy" area (we want to achieve a state with low ERK and high RKIPP_RP concentrations)
def pathway_evaluate_worker_fn(key, batched_intervention_params, intervention_fn, batched_system_rollout, observed_node_ids, target_endpoint):

    # rollout the system without perturbation
    key, *subkeys = jrandom.split(key, num=batched_intervention_params.y0.shape[0]+1)
    batched_system_outputs, _  = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, grn.NullIntervention(), None)

    # rollout the system with noise perturbation
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_noise_perturbation_params, _ = batched_noise_perturbation_generator(jnp.array(subkeys), batched_system_outputs)
    key, *subkeys = jrandom.split(key, num=batched_intervention_params.y0.shape[0]+1)
    batched_system_outputs_noise, _  = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, noise_perturbation_fn, batched_noise_perturbation_params)
    loss = jnp.sqrt(jnp.square((batched_system_outputs_noise.ys[:, jnp.array(observed_node_ids), -1] - target_endpoint)).sum())

    # rollout the system with push perturbation
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_push_perturbation_params, _ = batched_push_perturbation_generator(jnp.array(subkeys), batched_system_outputs)
    key, *subkeys = jrandom.split(key, num=batched_intervention_params.y0.shape[0]+1)
    batched_system_outputs_push, _  = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, push_perturbation_fn, batched_push_perturbation_params)
    loss += jnp.sqrt(jnp.square((batched_system_outputs_push.ys[:, jnp.array(observed_node_ids), -1] - target_endpoint)).sum())

    # rollout the system with wall perturbation
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_wall_perturbation_params, _ = batched_wall_perturbation_generator(jnp.array(subkeys), batched_system_outputs)
    key, *subkeys = jrandom.split(key, num=batched_intervention_params.y0.shape[0]+1)
    batched_system_outputs_wall, _  = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, wall_perturbation_fn, batched_wall_perturbation_params)
    loss += jnp.sqrt(jnp.square((batched_system_outputs_wall.ys[:, jnp.array(observed_node_ids), -1] - target_endpoint)).sum())


    return loss.mean(),  [batched_system_outputs, batched_system_outputs_noise, batched_system_outputs_push, batched_system_outputs_wall], batched_wall_perturbation_params

pathway_evaluate_worker_fn = jit(jtu.Partial(pathway_evaluate_worker_fn, intervention_fn=custom_intervention_fn, 
                                             batched_system_rollout=batched_system_rollout, observed_node_ids=observed_node_ids, target_endpoint=target_endpoint))


# Optimize intervention (with random search)
n_trials = 20
batch_size = len(disease_pathways_ids) 
best_loss = 1e7

for trial_idx in range(n_trials):

    # Generate a random stepwise intervention on the controlled node (MEKPP)
    batched_interventions_params = DictTree()
    # Set last state of disease pathways as init state (via intervention)
    batched_interventions_params.y0 = eval_system_outputs_library.ys[disease_pathways_ids, :, -1]
    y_idx = system_rollout.grn_step.y_indexes[controlled_node]
    key, subkey = jrandom.split(key)
    batched_interventions_params.clamping[y_idx] = jnp.tile(jrandom.uniform(subkey, shape=(len(controlled_intervals),),
                                                                  minval=0.05*eval_system_outputs_library.ys[disease_pathways_ids,y_idx, -1].min(), 
                                                                  maxval=20*eval_system_outputs_library.ys[disease_pathways_ids,y_idx, -1].max()),
                                                            (batch_size, 1))
    # Calc loss
    key, subkey = jrandom.split(key)
    trial_loss, batched_outputs, wall_params = pathway_evaluate_worker_fn(subkey, batched_interventions_params)

    print(trial_loss)
    if trial_loss < best_loss: 
        best_loss = trial_loss
        best_params = batched_interventions_params
        best_outputs = batched_outputs
        display_wall_params = wall_params
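Once the random search has finished, the best clamping schedule can be inspected directly (a small illustrative addition, not in the original loop):

# Illustrative: print the best stepwise clamping schedule found for the controlled node
# (values are tiled across the batch, so the first row is representative)
print(f"Best loss: {best_loss}")
print(f"Best clamping values for {controlled_node} over intervals {controlled_intervals}:")
print(best_params.clamping[y_idx][0])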
Figure 24: (a) The 10 most robust identified pathways (average sensitivity <0.05) are displayed. We can see that most of them converge toward attractors in the "disease" region (orange). (b) Example of a stepwise intervention on MEKPP, found with simple random search, that we apply for 100 seconds to states stuck in the "disease" region. (c) The discovered intervention successfully brings all points from the "disease" region back closer to the target endpoint in the "healthy" region (green), under the various tested perturbations (not shown here).
Figure 25: Individual results after applying the MEKPP intervention, with and without the added perturbations.

👉 Our behavioral catalog allows us to identify robustly reachable goal states, in particular undesired "disease" states (which were not identified with random search), and to design interventions that robustly reset them