Authors | Affiliation | Published
---|---|---
Mayalen Etcheverry | INRIA, Flowers team, Poietis | September, 2023
Clément Moulin-Frier | INRIA, Flowers team |
Pierre-Yves Oudeyer | INRIA, Flowers team |
Michael Levin | The Levin Lab, Tufts University |
TL;DR
This tutorial accompanies our paper Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks.
It is intended to walk you through the set of tools we use to
(i) automatically explore the space of input stimuli of a given biological network in order to discover a diversity of behaviors in these systems,
(ii) analyze the robustness of the discovered abilities in order to infer the network's navigation competencies in the transcriptional space of gene activation, and
(iii) explore possible reuses of the constructed "behavioral catalog" in a biomedical context.
Throughout this tutorial, we will be using AutoDiscJax, a library built on top of jax and equinox to facilitate automated experimentation and simulation of biological network pathways.
AutoDiscJax follows two main design principles:
1) Everything is a module, where a module is simply a parametrized function that takes inputs and returns outputs (and log_data). All AutoDiscJax modules (adx.Module) are implemented as Equinox modules (eqx.Module), which essentially allows the function to be represented as a callable PyTree (and hence be compatible with JAX transformations) while keeping an intuitive API for model building (a Python class with a __call__ method). The only add-on with respect to Equinox is that when instantiating an adx.Module, the user must specify the structure, shape and dtype of the module's output PyTree (a minimal sketch of such an output specification is shown after this list).
2) An experiment pipeline defines (i) how modules interact sequentially and exchange information, and (ii) what information should be collected and saved in the experiment history.
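As referenced above, here is a minimal sketch of how such an output specification can be built with jax.tree_util, mirroring the pattern used later in this tutorial (the toy values are hypothetical, and the exact adx.Module constructor arguments may differ):
import jax.numpy as jnp
import jax.tree_util as jtu
out_tree = "placeholder"                                   # a PyTree template with a single leaf
out_treedef = jtu.tree_structure(out_tree)                 # the output's PyTree structure
out_shape = jtu.tree_map(lambda _: (2,), out_tree)         # shape of each output leaf
out_dtype = jtu.tree_map(lambda _: jnp.float32, out_tree)  # dtype of each output leaf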
AutoDiscJax provides a handful of already-implemented modules and pipelines to:
1) Simulate biological networks while intervening on them according to our needs (see Part 1)
2) Automatically organize experimentation in those systems, by implementing a variety of exploration approaches such as random, optimization-driven and curiosity-driven search (see Part 2 and Part 3)
3) Analyze the discoveries of the exploration method, for instance by testing their robustness to various perturbations (see Part 4)
Finally, AutoDiscJax takes advantage of JAX's main features (just-in-time compilation, automatic vectorization and automatic differentiation), which are especially advantageous for parallel experimentation and computational speedups, as well as for gradient-based optimization.
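As a quick illustration of these three JAX transformations on a toy function (nothing AutoDiscJax-specific here):
import jax.numpy as jnp
from jax import jit, vmap, grad
f = lambda x: jnp.sum(x ** 2)  # toy scalar function
fast_f = jit(f)                # just-in-time compilation
batched_f = vmap(f)            # automatic vectorization over a leading batch axis
df = grad(f)                   # automatic differentiation
print(batched_f(jnp.ones((4, 3))), df(jnp.ones(3)))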
Ordinary differential equation (ODE) models are widely used to represent the behavior of complex biological processes. These ODE models are often experimentally determined and curated by biologists, e.g. by combining experimental and data-driven methods to determine whether and how pairs of biomolecules interact. The resulting mathematical models are usually stored and exchanged using the Systems Biology Markup Language (SBML). Thanks to community efforts, large collections of published ODE models have been made publicly available on online databases, such as the BioModels website.
For this tutorial, we will be studying the gene regulatory network (GRN) model that describes the influence of RKIP on the ERK Signaling Pathway, as described in Cho et al.'s paper. It is hosted on the BioModels database under the BIOMD0000000647 identifier and described by the following reaction graph:
Whereas the SBML file provides information about the different species, parameters and reactions involved in this model, we must use our own tools for carrying out simulation studies.
In the first part of this tutorial, we will be showing how to use the SBMLtoODEjax library to automatically retrieve and parse the SBML file into a python file, and the AutoDiscJax library to easily simulate the model and manipulate it based on our needs.
biomodel_idx = 647
observed_node_names = ["ERK", "RKIPP_RP"]
Let's first use the SBMLtoODEjax library to download the SBML file and generate the corresponding python class that will allow us to simulate the model's dynamics.
out_model_sbml_filepath = f"data/biomodel_{biomodel_idx}.xml"
out_model_jax_filepath = f"data/biomodel_{biomodel_idx}.py"
# Download the SBML file
if not os.path.exists(out_model_sbml_filepath):
    model_xml_body = sbmltoodejax.biomodels_api.get_content_for_model(biomodel_idx)
    with open(out_model_sbml_filepath, 'w') as f:
        f.write(model_xml_body)
# Generation of the python class from the SBML file
if not os.path.exists(out_model_jax_filepath):
    model_data = sbmltoodejax.parse.ParseSBMLFile(out_model_sbml_filepath)
    sbmltoodejax.modulegeneration.GenerateModel(model_data, out_model_jax_filepath)
At this point, you should have created a biomodel_647.py
file that contains several variables and functional modules as follows:
The variables are instantiated this way because, for many models, the kinematic parameters are all constant ($w=\emptyset$), whereas for others some kinematic parameters $w(t)$ evolve in time in addition to the constant ones $c$.
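As a quick check, one can import the generated file and inspect its variables (a minimal sketch; we assume the generated biomodel_647.py follows the SBMLtoODEjax conventions, exposing y0, w0, c, t0 and the y_indexes mapping):
import importlib.util
spec = importlib.util.spec_from_file_location("biomodel_647", out_model_jax_filepath)
biomodel = importlib.util.module_from_spec(spec)
spec.loader.exec_module(biomodel)
print(biomodel.y_indexes)  # mapping from species names to their index in y
print(biomodel.y0)         # initial species amounts parsed from the SBML file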
For those who are interested in understanding more about the specifications of SBML files and SBMLtoODEjax files, we refer to the SBMLtoODEjax documentation: https://developmentalsystems.org/sbmltoodejax/design_principles.html
Ok, let's now create our first module: the System Rollout. In short, the system rollout can be seen as a wrapper of the previously-created sbmltoodejax ModelRollout module, but this time allowing us to apply all sorts of interventions and/or perturbations during the rollout simulation, without having to modify the rollout codebase. We'll see how to do that shortly, but first let's create our system rollout module and see what a simulation looks like.
When instantiating the ModelStep module, we should specify the desired time step $\Delta t$ and the ODE solver parameters $(atol, rtol, mxstep)$.
# System Rollout Config
system_rollout_config = Dict()
system_rollout_config.system_type = "grn"
system_rollout_config.model_filepath = out_model_jax_filepath # path of model class that we just created using sbmltoodejax
system_rollout_config.atol = 1e-6 # parameters for the ODE solver
system_rollout_config.rtol = 1e-12
system_rollout_config.mxstep = 1000
system_rollout_config.deltaT = 0.1 # the ODE solver will compute values every 0.1 second
system_rollout_config.n_secs = 1000 # number of seconds of one rollout in the system
system_rollout_config.n_system_steps = int(system_rollout_config.n_secs/system_rollout_config.deltaT) # total number of steps returned after a rollout
# Create the module
system_rollout = create_system_rollout_module(system_rollout_config)
# Get observed node ids
observed_node_ids = [system_rollout.grn_step.y_indexes[observed_node_names[0]], system_rollout.grn_step.y_indexes[observed_node_names[1]]]
Let's now simulate the module. By default, the rollout starts with the initial concentrations as given in the SBML file.
key, subkey = jrandom.split(key)
default_system_outputs, log_data = system_rollout(subkey) # all autodiscjax modules take a PRNG key as input (needed e.g. for random sampling operations within the module)
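A note on the key/subkey pattern used throughout this tutorial: JAX random functions are stateless, so we explicitly split the PRNG key and pass a fresh subkey to each stochastic call, never reusing a key:
key = jrandom.PRNGKey(0)          # create the initial PRNG key (done once)
key, subkey = jrandom.split(key)  # split: keep `key` for later, consume `subkey` now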
Here is what we obtain when simulating the network with the default initial conditions (see the original Figure 5 of Cho et al.'s paper):
We can also observe the resulting trajectories in phase space, also called transcriptional space for gene regulatory networks. Let's for instance say that we are interested in observing the activation response of two specific nodes: ERK and RKIPP_RP. By plotting their trajectory in the transcriptional space, we can see that they start at point A(0,0) and navigate until they reach a steady state at point B(0.036,0.055).
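A minimal plotting sketch for reproducing such a phase-space view (assuming default_system_outputs.ys holds the simulated amounts with shape (n_nodes, n_system_steps); the styling of this notebook's figures is omitted):
import plotly.graph_objects as go
ys = default_system_outputs.ys
fig = go.Figure(go.Scatter(x=ys[observed_node_ids[0]], y=ys[observed_node_ids[1]], mode="lines"))
fig.update_layout(xaxis_title="ERK", yaxis_title="RKIPP_RP")
fig.show()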
We would now like to perform several interventions on the system rollout to see how they influence the trajectory of the network in transcriptional space. The system rollout module allows us to do so by specifying an intervention function $(y^-, y, w^-, w, c^-, c, t^-, t) \mapsto (y, w, c, t)$ that is called at every step during a rollout, where $(y^-,w^-,c^-,t^-)$ are the variable values before the step (t-1) and $(y,w,c,t)$ are the variable values that would be returned by the ModelStep function without intervention (and that are overwritten by the intervention function).
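Schematically, a custom intervention is just a function implementing this mapping; for instance, clamping the first node to a fixed amount could look like the sketch below (illustrative only; the exact AutoDiscJax signature, e.g. how intervention_params are passed, may differ from this simplified form):
def clamp_first_node(y_, y, w_, w, c_, c, t_, t):
    y = y.at[0].set(1.0)  # overwrite node 0's amount at every step
    return y, w, c, t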
Several intervention functions are already provided in AutoDiscJax. In particular, the PiecewiseIntervention module allows us to control the variables of the network with a piecewise-defined function, which makes it possible to apply a broad range of interventions based on one's needs and constraints. The module configuration allows us to specify which variable(s) (stored in y, w or c) to intervene on, when (on which time intervals), and how (e.g. by setting the variable value using PiecewiseSetConstantIntervention, or by adding to the variable's current value using PiecewiseAddConstantIntervention).
Let's try it and see how our ERK-pathway network reacts to the following interventions: 1) start the trajectory in ERK-RKIPP_RP space from another point A'(0.03, 0.02), 2) clamp node MEKPP to 2.5 during the first 10 seconds and then to 1.0 during 10 additional seconds at t=400, 3) change the value of the kinematic parameter k5 from its default value to 0.1.
# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_values = [[0.03], [0.02]]
intervention_fn_1 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_1 = DictTree() # A DictTree is a Dict container (i.e. a dictionary where items can be get and set like attributes) that is registered as a Jax PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_1.y[node_idx] = jnp.array(node_value)
# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_1, intervention_params=intervention_params_1)
👉 Interestingly, we can see that despite forcing the network to start at another point A', the network still converges to the same point B. Moreover, instead of going directly from A' to B, the network seems to make a "detour" to rejoin its initial trajectory, and then follows that trajectory again until it successfully reaches point B.
# Create the intervention
controlled_intervals = [[0, 10], [400, 410]]
controlled_node_names = ["MEKPP"]
controlled_node_values = [[2.5, 1.0]]
intervention_fn_2 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_2 = DictTree() # A DictTree is a Dict container (i.e. a dictionary where items can be get and set like attributes) that is registered as a Jax PyTree
for (node_name, node_value) in zip(controlled_node_names, controlled_node_values):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    intervention_params_2.y[node_idx] = jnp.array(node_value)
# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_2, intervention_params=intervention_params_2)
👉 Here, the clamping of MEKPP seems to have some effect on ERK but not on RKIPP_RP. Indeed, after the initial clamping of MEKPP to 2.5 (during 10 seconds), the trajectory of the ERK-RKIPP_RP pair still follows a very similar S-shaped curve and arrives close to the original point B, but with a slightly lower ERK expression level (t=400). From the moment we re-clamp MEKPP to a lower activation (1.0 for 10 seconds at t=400), we see an effect on the ERK expression level, where the final steady state B' gets shifted to the right.
# Create the intervention
controlled_intervals = [[0, system_rollout_config.deltaT/2.0]]
controlled_param_names = ["k5"]
controlled_param_values = [[0.1]]
intervention_fn_3 = grn.PiecewiseSetConstantIntervention(time_to_interval_fn=grn.TimeToInterval(intervals=controlled_intervals))
intervention_params_3 = DictTree()
for (param_name, param_value) in zip(controlled_param_names, controlled_param_values):
    param_idx = system_rollout.grn_step.c_indexes[param_name] # this time we specify the intervention parameter value for the key "c"
    intervention_params_3.c[param_idx] = jnp.array(param_value)
    print(f"Initial {param_name} param value: {system_rollout.c[param_idx]}, changed to {param_value}")
# Run the system with the intervention
key, subkey = jrandom.split(key)
system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn_3, intervention_params=intervention_params_3)
👉 Here we can see that changing the parameter k5 shifts the trajectory end point quite significantly from B to B', but that qualitatively the trajectory seems to keep a similar "S" shape.
Similarly, we would like to apply several perturbations to the system rollout to see how they influence the trajectory of the network in transcriptional space. In AutoDiscJax, a perturbation is implemented in the same manner as an intervention (by specifying a perturbation function that is also called at every step during a rollout). The difference is only conceptual: a perturbation is meant to represent uncontrolled events (such as noise or other environmental stresses), whereas an intervention is meant to represent controlled events (i.e. applied by the experimenter, such as drug stimuli).
Several perturbation functions are already provided in AutoDiscJax.
Let's see how our ERK-pathway network reacts to the following perturbations: 1) random noise added to all node activities at regular time intervals, 2) a "push" perturbation that displaces the observed nodes at a given time, and 3) "wall" perturbations that act as obstacles in the observed ERK-RKIPP_RP phase space.
# we apply noise between t=0 and t=80 secs to ALL nodes
start = 0
end = 80
deltaT = system_rollout.deltaT
perturbed_node_ids = list(range(len(system_rollout.grn_step.y_indexes)))
# Create the noise perturbation generator module
noise_perturbation_generator_config = Dict()
noise_perturbation_generator_config.perturbation_type = "noise"
noise_perturbation_generator_config.perturbed_intervals = [[t-deltaT/2, t+deltaT/2] for t in jnp.linspace(start, end, 2 + int((end-start)/5))[1:-1]] #add noise every 5 sec between start and end
noise_perturbation_generator_config.perturbed_node_ids = perturbed_node_ids
noise_perturbation_generator_config.std = 0.01
noise_perturbation_generator, noise_perturbation_fn = create_perturbation_module(noise_perturbation_generator_config)
# Run the system with the perturbation
key, subkey = jrandom.split(key)
noise_perturbation_params, log_data = noise_perturbation_generator(subkey, default_system_outputs)
noise_system_outputs, log_data = system_rollout(subkey, None, None,
noise_perturbation_fn, noise_perturbation_params)
push_t = 3
# Create the push perturbation generator module
push_perturbation_generator_config = Dict()
push_perturbation_generator_config.perturbation_type = "push"
push_perturbation_generator_config.perturbed_intervals = [[push_t-deltaT/2, push_t+deltaT/2]]
push_perturbation_generator_config.perturbed_node_ids = observed_node_ids
push_perturbation_generator_config.magnitude = 0.1
push_perturbation_generator, push_perturbation_fn = create_perturbation_module(push_perturbation_generator_config)
# Run the system with the perturbation
key, subkey = jrandom.split(key)
push_perturbation_params, log_data = push_perturbation_generator(subkey, default_system_outputs)
push_system_outputs, log_data = system_rollout(subkey, None, None,
push_perturbation_fn, push_perturbation_params)
# Create the wall perturbation generator module
wall_perturbation_generator_config = Dict()
wall_perturbation_generator_config.perturbation_type = "wall"
wall_perturbation_generator_config.wall_type = "force_field"
wall_perturbation_generator_config.perturbed_intervals = [[0, system_rollout_config.n_secs]]
wall_perturbation_generator_config.perturbed_node_ids = observed_node_ids
wall_perturbation_generator_config.n_walls = 2
wall_perturbation_generator_config.walls_intersection_window = [[0.1, 0.15], [0.85, 0.9]] # in distance travelled from 0 (start point A) to 1.0 (end point B)
wall_perturbation_generator_config.walls_length_range = [[0.1, 0.1], [0.1, 0.1]]
wall_perturbation_generator_config.walls_sigma = [1e-2, 1e-4]
wall_perturbation_generator, wall_perturbation_fn = create_perturbation_module(wall_perturbation_generator_config)
# Run the system with the perturbation
key, subkey = jrandom.split(key)
wall_perturbation_params, log_data = wall_perturbation_generator(subkey, default_system_outputs)
wall_system_outputs, log_data = system_rollout(subkey, None, None,
wall_perturbation_fn, wall_perturbation_params)
👉 Again, the network appears robust to these perturbations, recovering its course toward the steady state B in each case.
To finish Part 1 of this tutorial, let's see how AutoDiscJax allows us to perform simulations in parallel.
# Put the system in batch mode
batched_system_rollout = vmap(system_rollout, in_axes=(0, None, 0))
# Create the M=10 interventions (vector of starting positions between minval and maxval)
M = 10
controlled_node_names = ["ERK", "RKIPP_RP"]
controlled_node_minvals = [0.02, 0.02]
controlled_node_maxvals = [0.08, 0.08]
batched_interventions_params_1 = DictTree()
for (node_name, node_minval, node_maxval) in zip(controlled_node_names, controlled_node_minvals, controlled_node_maxvals):
    node_idx = system_rollout.grn_step.y_indexes[node_name]
    key, subkey = jrandom.split(key)
    batched_interventions_params_1.y[node_idx] = jrandom.uniform(subkey, shape=(M, 1), minval=node_minval, maxval=node_maxval)
key, *subkeys = jrandom.split(key, num=M + 1)
batched_system_outputs, log_data = batched_system_rollout(jnp.array(subkeys), intervention_fn_1, batched_interventions_params_1)
print(default_system_outputs.ys.shape, batched_system_outputs.ys.shape)
Using the jax vmap transformation, the above code automatically vectorizes the call to the system rollout module over different intervention parameters and stores the vectorized results in the batched_system_outputs output variable. This is a very convenient (and fast) way to test several interventions on the biological network.
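For additional speedups, the vectorized rollout can also be just-in-time compiled; since AutoDiscJax modules are Equinox modules (callable PyTrees that may hold non-array attributes), Equinox's filtered transformation is a safe default here (a sketch, untested on this exact module):
import equinox as eqx
fast_batched_rollout = eqx.filter_jit(vmap(system_rollout, in_axes=(0, None, 0)))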
This plot shows the M=10 resulting trajectories obtained in the transcriptional space when applying intervention 1 for 10 different starting conditions $A_0, \dots, A_9$.
👉 We can see that despite starting the simulations at 10 different positions (initial amounts of ERK and RKIPP_RP), they all converge to the same steady-state point B.
Automated experimentation: we can now plug in an explorer module that generates a sequence of interventions given an experimental budget N. There are many reasons why we would like to do so: steady-state analysis, scientific discovery, drug development, etc.
Challenge 1: a limited budget of experiments, yet a high-dimensional parameter space.
batch_size = 100
random_intervention_generator_config = Dict()
random_intervention_generator_config.intervention_type = "set_uniform"
random_intervention_generator_config.controlled_node_ids = list(range(len(default_system_outputs.ys)))
random_intervention_generator_config.controlled_intervals = [[0, system_rollout.deltaT/2.0]]
random_search_discoveries = {}
for r in [1, 10, 100]:
    controlled_node_minvals = default_system_outputs.ys.min(-1)/r
    controlled_node_maxvals = default_system_outputs.ys.max(-1)*r
    intervention_params_tree = DictTree()
    for y_idx in random_intervention_generator_config.controlled_node_ids:
        intervention_params_tree.y[y_idx] = "placeholder"
    random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
    random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),),
                                                                  intervention_params_tree)
    random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)
    random_intervention_generator_config.low = DictTree()
    random_intervention_generator_config.high = DictTree()
    for (node_idx, node_minval, node_maxval) in zip(random_intervention_generator_config.controlled_node_ids, controlled_node_minvals, controlled_node_maxvals):
        random_intervention_generator_config.low.y[node_idx] = jnp.array([node_minval])
        random_intervention_generator_config.high.y[node_idx] = jnp.array([node_maxval])
    random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)
    batched_random_intervention_generator = vmap(random_intervention_generator)
    # Generate random interventions (batch mode)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_interventions_params, log_data = batched_random_intervention_generator(jnp.array(subkeys))
    # Rollout the system (batch mode)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs, log_data = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_interventions_params)
    # Store the discoveries
    random_search_discoveries[r] = batched_system_outputs
👉 With a too-constrained range (r=1), random search misses a big part of what is feasible (in terms of reachable transcriptional space). Choosing a looser input parameter space range (e.g. r=10) shows that it is possible to find novel steady states (e.g. with ERK > 5). However, this also means that the exploration space is bigger and hence harder to explore. In fact, we can see that for r=10 and r=100, random search is not very efficient, as most of the discoveries are localized on the ERK=0 axis (i.e. fall in the ERK=0 valley/attractor).
# Define loss function (L2 distance to target point)
def evaluate_worker_fn(key, intervention_params, intervention_fn, system_rollout, observed_node_ids, target_point, low, high):
    # rollout the system with the given intervention parameters
    key, subkey = jrandom.split(key)
    system_outputs, log_data = system_rollout(subkey, intervention_fn, intervention_params)
    # Get the trajectory's final point
    reached_point = system_outputs.ys[jnp.array(observed_node_ids), -1]
    # Calc the L2 distance to the target point, normalized by the (low, high) extent
    loss = jnp.sqrt((jnp.square((reached_point - target_point)/(high-low))).sum())
    # Append info to log data
    log_data = DictTree()
    log_data.reached_point = reached_point
    log_data.loss = loss
    return loss, log_data
previously_reached_points = random_search_discoveries[100].ys[:, jnp.array(observed_node_ids), -1]
low = jnp.nanmin(previously_reached_points, axis=0)
high = jnp.nanmax(previously_reached_points, axis=0)
target_point_1 = jnp.array([200., 4.])
evaluate_worker_fn_1 = jtu.Partial(evaluate_worker_fn, intervention_fn=intervention_fn, system_rollout=system_rollout,
observed_node_ids=observed_node_ids, target_point=target_point_1, low=low, high=high)
target_point_2 = jnp.array([200., 10.])
evaluate_worker_fn_2 = jtu.Partial(evaluate_worker_fn, intervention_fn=intervention_fn, system_rollout=system_rollout, observed_node_ids=observed_node_ids, target_point=target_point_2, low=low, high=high)
# Create SGD Optimizer
intervention_optimizer_config = Dict()
intervention_optimizer_config.n_optim_steps = 100
intervention_optimizer_config.n_workers = 1
intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda node: 0.0, random_intervention_generator.low)
intervention_optimizer_config.lr = jtu.tree_map(lambda low, high: 0.01*(high-low),
random_intervention_generator.low, random_intervention_generator.high)
optimizer = optimizers.SGDOptimizer(random_intervention_generator.out_treedef,
random_intervention_generator.out_shape,
random_intervention_generator.out_dtype,
random_intervention_generator.low,
random_intervention_generator.high,
intervention_optimizer_config.n_optim_steps,
intervention_optimizer_config.n_workers,
intervention_optimizer_config.init_noise_std,
intervention_optimizer_config.lr
)
# Start position 1
start_intervention_params_1 = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    start_intervention_params_1.y[node_idx] = default_system_outputs.ys[node_idx, 0][jnp.newaxis]
# Start position 2
selected_intervention_ids, distances = nearest_neighbors(target_point_1/(high-low), previously_reached_points/(high-low), k=1)
start_intervention_params_2 = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    start_intervention_params_2.y[node_idx] = random_search_discoveries[100].ys[selected_intervention_ids[0], node_idx, 0][jnp.newaxis]
# Optimization Run 1
print(f"Start position B1 {default_system_outputs.ys[jnp.array(observed_node_ids), -1]} toward target position G1 {target_point_1}")
key, subkey = jrandom.split(key)
optimized_intervention_params_1, log_data_1 = optimizer(subkey, start_intervention_params_1, evaluate_worker_fn_1)
# Optimization Run 2
print(f"Start position B2 {previously_reached_points[selected_intervention_ids][0]} toward target position G1 {target_point_1}")
key, subkey = jrandom.split(key)
optimized_intervention_params_2, log_data_2 = optimizer(subkey, start_intervention_params_2, evaluate_worker_fn_1)
# Optimization Run 3
print(f"Start position B1 {default_system_outputs.ys[jnp.array(observed_node_ids), -1]} toward target position G2 {target_point_2}")
key, subkey = jrandom.split(key)
optimized_intervention_params_3, log_data_3 = optimizer(subkey, start_intervention_params_1, evaluate_worker_fn_2)
# Optimization Run 4
print(f"Start position B2 {previously_reached_points[selected_intervention_ids][0]} toward target position G2 {target_point_2}")
key, subkey = jrandom.split(key)
optimized_intervention_params_4, log_data_4 = optimizer(subkey, start_intervention_params_2, evaluate_worker_fn_2)
👉 Interestingly, we can see that all optimization runs make progress (the training loss decreases) but follow specific curves in the transcriptional space (rather than going straight toward the targets). Those curves follow the valleys of the optimization landscape, and we can clearly understand how the choice of the starting point conditions the optimization success, as the optimization run might get stuck in a valley or in a local minimum.
For instance here, the default initial point B1 (as given in the SBML file) fails to achieve both targets (G1 and G2): its training loss reaches a plateau (blue and green curves) and the optimization gets stuck, or at least makes very slow progress (as can be seen from the blue and green trajectories in transcriptional space). Another starting point $B_2$, which was previously found by random search (and selected because it was the closest to G1 among all discoveries), successfully manages to get very close to the target goal G1 (orange trajectory). However, it fails to reach the other target point G2 (red trajectory).
👉 This shows the importance of having a good pool of initial discoveries for reaching desired targets with optimization. Because random search is not very efficient at covering the map of possible steady states, it is likely that optimization will fail for many possible targets G. Can we find a more efficient way to populate the pool of discoveries, given the same experimental budget?
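Before diving into the individual modules, here is the high-level loop that an intrinsically-motivated goal exploration process (IMGEP) runs, written as schematic pseudocode (the actual wiring is done by the run_imgep_experiment pipeline used later in this tutorial):
# history <- run and store a few batches of random interventions
# repeat until the experimental budget N is exhausted:
#     g       <- goal_generator(reached goals so far)        # sample a target goal
#     theta   <- gc_intervention_selector(g, history)        # pick the closest past intervention
#     theta'  <- gc_intervention_optimizer(theta, g)         # locally mutate/optimize it toward g
#     outcome <- system_rollout(theta')                      # run the experiment
#     history <- history + (theta', goal_embedding_encoder(outcome))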
# example: generate a random set of intervention parameters between low and high
key, subkey = jrandom.split(key)
intervention_params, log_data = random_intervention_generator(subkey)
print(jtu.tree_map(lambda node: node.shape, intervention_params))
# example in batch mode
key, *subkeys = jrandom.split(key, num=batch_size + 1)
interventions_params, log_data = batched_random_intervention_generator(jnp.array(subkeys))
print(jtu.tree_map(lambda node: node.shape, interventions_params))
goal_embedding_encoder_config = Dict()
goal_embedding_encoder_config.encoder_type = "filter"
goal_embedding_tree = "placeholder"
goal_embedding_encoder_config.out_treedef = jtu.tree_structure(goal_embedding_tree)
goal_embedding_encoder_config.out_shape = jtu.tree_map(lambda _: (len(observed_node_ids), ), goal_embedding_tree)
goal_embedding_encoder_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, goal_embedding_tree)
goal_embedding_encoder_config.filter_fn = jtu.Partial(lambda system_outputs: system_outputs.ys[..., observed_node_ids, -1])
goal_embedding_encoder = create_goal_embedding_encoder_module(goal_embedding_encoder_config)
batched_goal_embedding_encoder = vmap(goal_embedding_encoder)
# example: encode the default system outputs
key, subkey = jrandom.split(key)
reached_goal_embedding, log_data = goal_embedding_encoder(subkey, default_system_outputs)
print(reached_goal_embedding)
# example in batch mode: encode system outputs discovered by random search
key, *subkeys = jrandom.split(key, num=batch_size + 1)
reached_goals_embeddings, log_data = batched_goal_embedding_encoder(jnp.array(subkeys), random_search_discoveries[100])
print(reached_goals_embeddings.shape)
goal_achievement_loss_config = Dict()
goal_achievement_loss_config.loss_type = "L2"
goal_achievement_loss = create_goal_achievement_loss_module(goal_achievement_loss_config)
batched_goal_achievement_loss = vmap(goal_achievement_loss)
# example
target_goal_embedding = target_point_1
key, subkey = jrandom.split(key)
gc_loss, log_data = goal_achievement_loss(subkey, reached_goal_embedding, target_goal_embedding)
print(gc_loss)
# example in batch mode
target_goals_embeddings = jnp.tile(target_goal_embedding[jnp.newaxis], (batch_size, 1))
key, *subkeys = jrandom.split(key, num=batch_size + 1)
gc_losses, log_data = batched_goal_achievement_loss(jnp.array(subkeys), reached_goals_embeddings, target_goals_embeddings)
print(gc_losses.shape)
goal_generator_config = DictTree()
goal_generator_config.out_treedef = goal_embedding_encoder_config.out_treedef
goal_generator_config.out_shape = goal_embedding_encoder_config.out_shape
goal_generator_config.out_dtype = goal_embedding_encoder_config.out_dtype
goal_generator_config.low = 0.0
goal_generator_config.high = None
goal_generator_config.generator_type = "hypercube"
goal_generator_config.hypercube_scaling = 1.3
goal_generator = create_goal_generator_module(goal_generator_config)
batched_goal_generator = vmap(goal_generator, in_axes=(0, None, None))
# example
key, subkey = jrandom.split(key)
next_target_goal, log_data = goal_generator(subkey, target_goals_embeddings, reached_goals_embeddings)
print(next_target_goal)
# example in batch mode
key, *subkeys = jrandom.split(key, num=batch_size + 1)
next_target_goals_embeddings, log_data = batched_goal_generator(jnp.array(subkeys), target_goals_embeddings, reached_goals_embeddings)
print(next_target_goals_embeddings.shape)
gc_intervention_selector_config = Dict()
gc_intervention_selector_config.selector_type = "nearest_neighbor"
gc_intervention_selector_config.loss_f = goal_achievement_loss.loss_f
gc_intervention_selector_config.k = 1
gc_intervention_selector = create_gc_intervention_selector_module(gc_intervention_selector_config)
batched_gc_intervention_selector = vmap(gc_intervention_selector, in_axes=(0, 0, None))
# example
key, *subkeys = jrandom.split(key, num=batch_size + 1)
source_interventions_ids, log_data = batched_gc_intervention_selector(jnp.array(subkeys), next_target_goals_embeddings, reached_goals_embeddings)
print(source_interventions_ids.shape)
gc_intervention_optimizer_config = Dict()
gc_intervention_optimizer_config.out_treedef = random_intervention_generator.out_treedef
gc_intervention_optimizer_config.out_shape = random_intervention_generator.out_shape
gc_intervention_optimizer_config.out_dtype = random_intervention_generator.out_dtype
gc_intervention_optimizer_config.low = random_intervention_generator.low
gc_intervention_optimizer_config.high = random_intervention_generator.high
gc_intervention_optimizer_config.optimizer_type = "EA"
gc_intervention_optimizer_config.n_optim_steps = 1
gc_intervention_optimizer_config.n_workers = 1
gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: 0.1 * (high - low),
gc_intervention_optimizer_config.low, gc_intervention_optimizer_config.high)
gc_intervention_optimizer = create_gc_intervention_optimizer_module(gc_intervention_optimizer_config)
null_perturbation_generator, null_perturbation_fn = create_perturbation_module(Dict(perturbation_type="null"))
null_rollout_statistics_encoder = create_rollout_statistics_encoder_module(Dict(statistics_type="null"))
partial_gc_intervention_optimizer = jtu.Partial(gc_intervention_optimizer,
perturbation_generator=null_perturbation_generator, perturbation_fn=null_perturbation_fn,
intervention_fn=intervention_fn, system_rollout=system_rollout,
goal_embedding_encoder=goal_embedding_encoder, goal_achievement_loss=goal_achievement_loss,
rollout_statistics_encoder=null_rollout_statistics_encoder
)
batched_gc_intervention_optimizer = vmap(partial_gc_intervention_optimizer, in_axes=(0, 0, 0, None))
# example: optimize the initial species amounts (from the default ones) to reach the target ERK-RKIPP_RP steady state G1 (200, 4)
key, subkey = jrandom.split(key)
optimized_intervention_params, log_data = partial_gc_intervention_optimizer(subkey, start_intervention_params_1, target_point_1, reached_goals_embeddings)
print(jtu.tree_map(lambda node: node.shape, optimized_intervention_params))
# example in batch mode: optimize the selected interventions (closest to target) toward their respective targets
previous_interventions_params = DictTree()
for node_idx in random_intervention_generator_config.controlled_node_ids:
    previous_interventions_params.y[node_idx] = random_search_discoveries[100].ys[:, node_idx, 0]
start_interventions_params = jtu.tree_map(lambda x: x[source_interventions_ids], previous_interventions_params)
key, *subkeys = jrandom.split(key, num=batch_size + 1)
optimized_interventions_params, log_data = batched_gc_intervention_optimizer(jnp.array(subkeys), start_interventions_params, next_target_goals_embeddings, reached_goals_embeddings)
print(jtu.tree_map(lambda node: node.shape, optimized_interventions_params))
👉 Progress is made toward G3 and G7. No progress is made toward G9 (the starting point was already quite close to it), but even if the new point Z9 is further from G9, it falls in an uncovered area and will be useful for future goals.
👉 With this simple example, we can already grasp why the IMGEP will be much more efficient at finding diverse possible final states.
# Run IMGEP
jax_platform_name = "cpu"
seed = 0
n_random_batches = 2
n_imgep_batches = 8
batch_size = 20
imgep_experiment_data_save_folder = "data/imgep_data"
if not os.path.exists(os.path.join(imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         imgep_experiment_data_save_folder,
                         random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)
# Run Random Search
rs_experiment_data_save_folder = "data/rs_data"
if not os.path.exists(os.path.join(rs_experiment_data_save_folder, "history.pickle")):
    run_rs_experiment(jax_platform_name, seed, n_random_batches+n_imgep_batches, batch_size,
                      rs_experiment_data_save_folder,
                      random_intervention_generator, intervention_fn,
                      null_perturbation_generator, null_perturbation_fn,
                      system_rollout, null_rollout_statistics_encoder,
                      out_sanity_check=False, save_modules=False, save_logs=False)
imgep_experiment_history = DictTree.load(os.path.join(imgep_experiment_data_save_folder, "history.pickle"))
imgep_reached_goals_embeddings = imgep_experiment_history.reached_goal_embedding_library
print(imgep_reached_goals_embeddings.shape)
rs_experiment_history = DictTree.load(os.path.join(rs_experiment_data_save_folder, "history.pickle"))
rs_reached_goals_embeddings = rs_experiment_history.system_output_library.ys[:, jnp.array(observed_node_ids), -1]
print(rs_reached_goals_embeddings.shape)
Intrinsically motivated goal exploration algorithms are designed to autonomously discover the widest range of diverse effects that can be produced in an initially unknown system (here, our biological network). Thus, a first way to evaluate an exploration algorithm is to measure how well and how fast it covers the state space, particularly in comparison with random search.
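Concretely, after $n$ discoveries with normalized endpoints $z_1, \dots, z_n$, coverage is measured here as the area of the union of small disks of radius $\epsilon$ centered on the endpoints, $C_n = \operatorname{area}\big(\bigcup_{i=1}^{n} B_\epsilon(z_i)\big)$, which is what calc_analytic_bc_coverage below computes with shapely.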
epsilon = 0.033
analytic_bc_space_low = jnp.minimum(jnp.nanmin(imgep_reached_goals_embeddings, 0), jnp.nanmin(rs_reached_goals_embeddings, 0))
analytic_bc_space_high = jnp.maximum(jnp.nanmax(imgep_reached_goals_embeddings, 0), jnp.nanmax(rs_reached_goals_embeddings, 0))
def calc_analytic_bc_coverage(reached_endpoints, epsilon):
    for step_idx, reached_endpoint in enumerate(reached_endpoints):
        if step_idx == 0:
            union_polygon = Point(reached_endpoints[0]).buffer(epsilon)
            covered_areas = [union_polygon.area]
        else:
            union_polygon = unary_union([union_polygon, Point(reached_endpoint).buffer(epsilon)])
            covered_areas.append(union_polygon.area)
    return union_polygon, covered_areas
imgep_reached_endpoints = (imgep_reached_goals_embeddings-analytic_bc_space_low) / (analytic_bc_space_high-analytic_bc_space_low)
imgep_union_polygon, imgep_covered_areas = calc_analytic_bc_coverage(imgep_reached_endpoints, epsilon=epsilon)
rs_reached_endpoints = (rs_reached_goals_embeddings-analytic_bc_space_low) / (analytic_bc_space_high-analytic_bc_space_low)
rs_union_polygon, rs_covered_areas = calc_analytic_bc_coverage(rs_reached_endpoints, epsilon=epsilon)
# calc clusters in behavior space
clusterer = hdbscan.HDBSCAN(min_cluster_size=10, cluster_selection_epsilon=0.1)
imgep_clusters_labels = clusterer.fit_predict(imgep_reached_endpoints)
rs_clusters_labels = clusterer.fit_predict(rs_reached_endpoints)
# project sampled params in 2D space with TSNE
imgep_sampled_params = jnp.array(list(imgep_experiment_history.intervention_params_library.y.values())).squeeze().transpose()
rs_sampled_params = jnp.array(list(rs_experiment_history.intervention_params_library.y.values())).squeeze().transpose()
all_sampled_params = jnp.concatenate([imgep_sampled_params, rs_sampled_params])
tsne = TSNE(n_components=2)
all_sampled_params = tsne.fit_transform(all_sampled_params)
imgep_sampled_params, rs_sampled_params = all_sampled_params[:200], all_sampled_params[200:] # 200 = (n_random_batches + n_imgep_batches) * batch_size discoveries per method
#@title [Figure 17]
fig_idx = 17
if nb_mode == "run":
    fig = make_subplots(rows=2, cols=2, horizontal_spacing=0.01, vertical_spacing=0.1, subplot_titles=["<b>(a) Curiosity Search</b><br> ","","<b>(b) Random Search</b><br> ",""])
    fig.update_layout(**default_layout, margin_t=40)
    fig.update_annotations(font_size=12, x=0.1)
    # Add points / shape contour per IMGEP cluster
    for row_idx, (clusters_labels, reached_endpoints, sampled_params) in enumerate(zip([imgep_clusters_labels, rs_clusters_labels],
                                                                                       [imgep_reached_endpoints, rs_reached_endpoints],
                                                                                       [imgep_sampled_params, rs_sampled_params],
                                                                                       )):
        for label_idx in set(clusters_labels):
            label_idx = label_idx.item()
            cluster_point_ids = jnp.where(clusters_labels==label_idx)[0]
            z_points = reached_endpoints[cluster_point_ids]
            i_points = sampled_params[cluster_point_ids]
            if label_idx < 0:
                show_shape = False
                marker_size = 4
                color_idx = 7
                name = f"N/A"
            else:
                show_shape = True
                marker_size = 4
                color_idx = [4,3,9,6][label_idx] if row_idx==0 else [2,8][label_idx]
                name = f"cluster {label_idx+1}"
            # shape contour
            if show_shape:
                for col_idx, (points, eps) in enumerate(zip([i_points, z_points], [1.5, 0.05])):
                    poly = unary_union([Point(point).buffer(eps) for point in points])
                    poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
                    x, y = [], []
                    if poly.geom_type == 'MultiPolygon':
                        for geom in poly.geoms:
                            geom_x, geom_y = geom.exterior.coords.xy
                            x.append(np.array(geom_x))
                            y.append(np.array(geom_y))
                    elif poly.geom_type == 'Polygon':
                        geom_x, geom_y = poly.exterior.coords.xy
                        x.append(np.array(geom_x))
                        y.append(np.array(geom_y))
                    for i, (x, y) in enumerate(zip(x, y)):
                        fig.add_trace(go.Scatter(x=x, y=y, fill="toself", fillcolor=default_colors_shade[color_idx],
                                                 name=name, legendgroup=row_idx, showlegend=(i==0)&(col_idx==0),
                                                 line=dict(color=default_colors[color_idx]), hoverinfo="skip"), row=row_idx+1, col=col_idx+1)
            # points
            fig.add_trace(go.Scatter(x=z_points[:,0], y=z_points[:,1],
                                     name=name, legendgroup=row_idx, showlegend=(label_idx==-1),
                                     mode="markers", marker_color=default_colors[color_idx], marker_size=marker_size),
                          row=row_idx+1, col=2)
            fig.add_trace(go.Scatter(x=i_points[:,0], y=i_points[:,1],
                                     name=name, legendgroup=row_idx, showlegend=False,
                                     mode="markers", marker_color=default_colors[color_idx], marker_size=marker_size),
                          row=row_idx+1, col=1)
    # Add background shape
    for row_idx in [1,2]:
        fig.add_vrect(x0=-0.1, x1=1.05,
                      fillcolor="#BAE1FF", opacity=.8,
                      line=dict(color="#6C8EBF", width=2),
                      annotation_text="Behavior Space <i>Z</i>", annotation_position="top", annotation_font=dict(color='black', size=14),
                      layer="below", row=row_idx, col=2)
        fig.add_vrect(x0=-40, x1=30,
                      fillcolor="#F6E785", opacity=.8,
                      line=dict(color="#C79714", width=2),
                      annotation_text="Intervention Space <i>I</i>", annotation_position="top", annotation_font=dict(color='black', size=14),
                      layer="below", row=row_idx, col=1)
    # Update Layout
    fig.update_layout(legend_tracegroupgap=320)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto1_fig_{fig_idx}.json")
elif nb_mode == "load":
    fig = plotly.io.read_json(f"figures/tuto1_fig_{fig_idx}.json")
# Display Fig
width, height = 900, 800
t = f"Mapping between Intervention Space and Behavior space. "
if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))
elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    display(Markdown(img_title))
Finally, we show that the IMGEP facilitates the engineer's task (e.g. when we are not sure about the appropriate range for the parameter and/or goal space).
r = 1
# Random generator initially constrained to the tight range (r=1) of what we know is feasible from the default rollout,
# but this time only for the first 10% of runs; then we let the algorithm adapt its parameter space extent
n_random_batches = 1
n_imgep_batches = 9
constrained_random_intervention_generator_config = deepcopy(random_intervention_generator_config)
for node_idx in random_intervention_generator_config.controlled_node_ids:
    constrained_random_intervention_generator_config.low.y[node_idx] = default_system_outputs.ys[node_idx].min() / r
    constrained_random_intervention_generator_config.high.y[node_idx] = default_system_outputs.ys[node_idx].max() * r
constrained_random_intervention_generator, intervention_fn = create_intervention_module(constrained_random_intervention_generator_config)
# Goal-conditioned optimizer: remove the (low, high) constraints
## and set the local mutation amplitude based on what we know is feasible (r=1) from the default rollout
adaptive_gc_intervention_optimizer_config = deepcopy(gc_intervention_optimizer_config)
adaptive_gc_intervention_optimizer_config.low = jtu.tree_map(lambda node: jnp.zeros_like(node), gc_intervention_optimizer_config.low)
adaptive_gc_intervention_optimizer_config.high = None
adaptive_gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: r*(high-low),
constrained_random_intervention_generator.low,
constrained_random_intervention_generator.high)
adaptive_gc_intervention_optimizer = create_gc_intervention_optimizer_module(adaptive_gc_intervention_optimizer_config)
# Run Adaptive IMGEP variant
adaptive_imgep_experiment_data_save_folder = "data/adaptive_imgep_data"
if not os.path.exists(os.path.join(adaptive_imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         adaptive_imgep_experiment_data_save_folder,
                         constrained_random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, adaptive_gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)
# Run Adaptive RandomMut variant (IMGEP with null goal generation and random intervention selection)
null_goal_generator_config = deepcopy(goal_generator_config)
null_goal_generator_config.hypercube_scaling = 0.0
null_goal_generator = create_goal_generator_module(null_goal_generator_config)
random_intervention_selector_config = deepcopy(gc_intervention_selector_config)
random_intervention_selector_config.selector_type = "random"
random_intervention_selector = create_gc_intervention_selector_module(random_intervention_selector_config)
adaptive_rmut_experiment_data_save_folder = "data/adaptive_rmut_data"
if not os.path.exists(os.path.join(adaptive_rmut_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         adaptive_rmut_experiment_data_save_folder,
                         constrained_random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         null_goal_generator, random_intervention_selector, adaptive_gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)
adaptive_imgep_experiment_history = DictTree.load(os.path.join(adaptive_imgep_experiment_data_save_folder, "history.pickle"))
adaptive_imgep_reached_goals_embeddings = adaptive_imgep_experiment_history.reached_goal_embedding_library
adaptive_rmut_experiment_history = DictTree.load(os.path.join(adaptive_rmut_experiment_data_save_folder, "history.pickle"))
adaptive_rmut_reached_goals_embeddings = adaptive_rmut_experiment_history.system_output_library.ys[:, jnp.array(observed_node_ids), -1]
adaptive_analytic_bc_space_low = jnp.minimum(jnp.nanmin(adaptive_imgep_reached_goals_embeddings, 0), jnp.nanmin(adaptive_rmut_reached_goals_embeddings, 0))
adaptive_analytic_bc_space_high = jnp.maximum(jnp.nanmax(adaptive_imgep_reached_goals_embeddings, 0), jnp.nanmax(adaptive_rmut_reached_goals_embeddings, 0))
adaptive_imgep_reached_endpoints = (adaptive_imgep_reached_goals_embeddings-adaptive_analytic_bc_space_low) / (adaptive_analytic_bc_space_high-adaptive_analytic_bc_space_low)
adaptive_imgep_union_polygon, adaptive_imgep_covered_areas = calc_analytic_bc_coverage(adaptive_imgep_reached_endpoints, epsilon=epsilon)
adaptive_rmut_reached_endpoints = (adaptive_rmut_reached_goals_embeddings-adaptive_analytic_bc_space_low) / (adaptive_analytic_bc_space_high-adaptive_analytic_bc_space_low)
adaptive_rmut_union_polygon, adaptive_rmut_covered_areas = calc_analytic_bc_coverage(adaptive_rmut_reached_endpoints, epsilon=epsilon)
# We test only the last 40 trajectories (to go faster)
eval_system_outputs_library = jtu.tree_map(lambda node: node[-2*batch_size:], imgep_experiment_history.system_output_library)
eval_intervention_params_library = jtu.tree_map(lambda node: node[-2*batch_size:], imgep_experiment_history.intervention_params_library)
def calc_settling_time(dist_vals, settling_time_threshold):
    # assume normalized dist_vals starting from 1 and finishing at 0;
    # returns the last index where the distance is still above the threshold
    settling_time = jnp.where(~(dist_vals < settling_time_threshold), size=len(dist_vals), fill_value=-1)[0].max()
    return settling_time
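A quick sanity check of this helper on a toy distance curve (illustrative values):
dists = jnp.array([1.0, 0.5, 0.1, 0.01, 0.005])
print(calc_settling_time(dists, settling_time_threshold=0.02))  # -> 2, the last index with distance >= 0.02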
def calc_travelling_time(trajectories):
    distance_travelled = jnp.cumsum(jnp.sqrt(jnp.sum(jnp.diff(trajectories, axis=-1)**2, axis=-2)), axis=-1)
    distance_travelled = distance_travelled / distance_travelled.max(-1) # normalize between 0 and 1
    T10 = jnp.where(distance_travelled >= 0.1, size=distance_travelled.shape[-1], fill_value=-1)[0][0]
    T90 = jnp.where(distance_travelled >= 0.9, size=distance_travelled.shape[-1], fill_value=-1)[0][0]
    return T10, T90
def calc_trajectories_statistics(trajectories, deltaT, settling_time_threshold):
    trajectories = trajectories[..., 1:] # remove first step (we don't want to take into account big jumps happening in the first step)
    # settling time: first time T such that the distance between y(t) and yfinal ≤ 0.02 × |yfinal – yinit| for t ≥ T
    # normalize such that origin is final point and unit=(end-origin)
    extent = (trajectories.max(-1) - trajectories.min(-1))
    extent = extent.at[extent == 0.].set(1.)
    normalized_trajectories = trajectories / extent[..., jnp.newaxis]
    distance_to_target = jnp.linalg.norm(normalized_trajectories - normalized_trajectories[..., -1][..., jnp.newaxis], axis=1)
    distance_to_target = distance_to_target / distance_to_target[:, 0][:, jnp.newaxis]
    settling_times = vmap(calc_settling_time, in_axes=(0, None))(distance_to_target, settling_time_threshold)
    # travelling time: time it takes for the response to travel from 10% to 90% of the way from yinit to yfinal
    T10s, T90s = vmap(calc_travelling_time)(normalized_trajectories)
    # detours (duration and area)
    detours_duration = []
    detours_area = []
    detours_timesteps = []
    for sample_idx in range(len(distance_to_target)):
        detour_timesteps = []
        detour_duration = 0.
        detour_area = 0.
        if settling_times[sample_idx] > 0:
            cur_distance_to_target = distance_to_target[sample_idx, :settling_times[sample_idx]]
            is_distance_increasing = jnp.concatenate(
                [jnp.array([False]), jnp.diff(cur_distance_to_target) > 0])
            is_distance_decreasing = jnp.concatenate(
                [jnp.array([True]), jnp.diff(cur_distance_to_target) < 0])
            start_detour_timesteps = jnp.where(is_distance_decreasing[:-1] & is_distance_increasing[1:])[0]
            if len(start_detour_timesteps) > 0:
                start_detour_dist_vals = cur_distance_to_target[start_detour_timesteps]
                end_detour_timesteps = []
                for start_detour_timestep, start_detour_dist_val in zip(start_detour_timesteps, start_detour_dist_vals):
                    possible_detour_timesteps = jnp.where((cur_distance_to_target[:-1] >= start_detour_dist_val) &
                                                          (cur_distance_to_target[1:] <= start_detour_dist_val))[0] + 1
                    # take the first time step (after start_detour_timestep) where the distance curve crosses back y=start_detour_dist_val
                    # if there is no crossing back before the settling time, we consider the settling time as the end of the detour
                    possible_end_detour_timesteps = possible_detour_timesteps[possible_detour_timesteps > start_detour_timestep]
                    if len(possible_end_detour_timesteps) > 0:
                        end_detour_timestep = possible_end_detour_timesteps[0]
                    else:
                        end_detour_timestep = settling_times[sample_idx]-1
                    end_detour_timesteps.append(end_detour_timestep)
                # calc union of intervals (in case some overlap due to noise)
                detour_timesteps = jnp.where(jnp.array([(jnp.arange(len(cur_distance_to_target)) >= start) &
                                                        (jnp.arange(len(cur_distance_to_target)) <= end)
                                                        for (start, end) in
                                                        zip(start_detour_timesteps, end_detour_timesteps)]).any(0))[0]
                detour_duration = len(detour_timesteps)
                rel_start_detours_timesteps = jnp.concatenate([jnp.array([0]), jnp.where((detour_timesteps[1:] - detour_timesteps[:-1]) > 1)[0]+1])
                rel_end_detours_timesteps = jnp.roll(rel_start_detours_timesteps-1, -1)
                detour_polygon = Polygon()
                valid_detour_timesteps = jnp.empty((0,))
                for start, end in zip(rel_start_detours_timesteps, rel_end_detours_timesteps):
                    detour_points = normalized_trajectories[sample_idx][:, detour_timesteps[start:end]].transpose()
                    if len(detour_points) >= 3:
                        cur_detour_polygon = Polygon([*detour_points])
                        if cur_detour_polygon.is_valid:
                            detour_polygon = unary_union([detour_polygon, cur_detour_polygon])
                            valid_detour_timesteps = jnp.concatenate([valid_detour_timesteps, detour_timesteps[start:end]])
                detour_area = detour_polygon.area
        detours_timesteps.append(detour_timesteps)
        detours_duration.append(detour_duration)
        detours_area.append(detour_area)
    trajectories_statistics = DictTree()
    trajectories_statistics.distance_to_target = distance_to_target
    trajectories_statistics.settling_times = settling_times
    trajectories_statistics.T10s = T10s
    trajectories_statistics.T90s = T90s
    trajectories_statistics.detours_timesteps = detours_timesteps
    trajectories_statistics.detours_duration = jnp.array(detours_duration)
    trajectories_statistics.detours_area = jnp.array(detours_area)
    return trajectories_statistics
deltaT = system_rollout_config.deltaT
settling_time_threshold = 0.02
trajectories = eval_system_outputs_library.ys[:, jnp.array(observed_node_ids), :]
trajectories_statistics = calc_trajectories_statistics(trajectories, deltaT, settling_time_threshold)
# Update the system_rollout to run for a bit longer
eval_system_rollout_config = deepcopy(system_rollout_config)
eval_system_rollout_config.n_secs = int(system_rollout_config.n_secs*1.2)
eval_system_rollout_config.n_system_steps = int(eval_system_rollout_config.n_secs/eval_system_rollout_config.deltaT)
eval_system_rollout = create_system_rollout_module(eval_system_rollout_config)
batched_eval_system_rollout = vmap(eval_system_rollout, in_axes=(0, None, 0, None, 0))
# perturbation hyperparams
perturbation_min_duration = 50
perturbation_max_duration = 500
T10 = jnp.median((trajectories_statistics.T10s+1))*deltaT
T90 = jnp.median((trajectories_statistics.T90s+1))*deltaT
start = T10
end = min(max(T90, start+perturbation_min_duration), start+perturbation_max_duration)
test_tasks = {
"noise_std": [0.001, 0.005, 0.01],
"noise_period": [10, 5, 1],
"push_magnitude": [0.05, 0.1, 0.15],
"push_number": [1, 2, 3],
"wall_length": [0.05, 0.1, 0.15],
"wall_number": [1, 2, 3]
}
def get_perturbation_generator_config(var_name, var_val):
    perturbation_generator_config = Dict()
    if var_name.split("_")[0] == "noise":
        perturbation_generator_config = Dict()
        perturbation_generator_config.perturbation_type = "noise"
        perturbation_generator_config.perturbed_node_ids = random_intervention_generator_config.controlled_node_ids
        if var_name.split("_")[1] == "std":
            n_noises = int((end-start)//test_tasks["noise_period"][1])
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + n_noises)[1:-1]]
            perturbation_generator_config.std = var_val
        elif var_name.split("_")[1] == "period":
            n_noises = int((end - start) // float(var_val))
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + n_noises)[1:-1]]
            perturbation_generator_config.std = test_tasks["noise_std"][1]
    elif var_name.split("_")[0] == "push":
        perturbation_generator_config = Dict()
        perturbation_generator_config.perturbation_type = "push"
        perturbation_generator_config.perturbed_node_ids = observed_node_ids
        if var_name.split("_")[1] == "magnitude":
            perturbation_generator_config.perturbed_intervals = [[int((T90+T10)/2.)-deltaT/2, int((T90+T10)/2.)+deltaT/2]]
            perturbation_generator_config.magnitude = var_val
        elif var_name.split("_")[1] == "number":
            perturbation_generator_config.perturbed_intervals = [[int(t)-deltaT/2, int(t)+deltaT/2] for t in jnp.linspace(start, end, 2 + var_val)[1:-1]]
            perturbation_generator_config.magnitude = test_tasks["push_magnitude"][1]
    elif var_name.split("_")[0] == "wall":
        perturbation_generator_config = Dict()
        perturbation_generator_config.perturbation_type = "wall"
        perturbation_generator_config.wall_type = "force_field"
        perturbation_generator_config.perturbed_intervals = [[0, eval_system_rollout_config.n_secs]]
        perturbation_generator_config.perturbed_node_ids = observed_node_ids
        perturbation_generator_config.walls_sigma = [1e-2, 1e-4]
        if var_name.split("_")[1] == "length":
            perturbation_generator_config.n_walls = 1
            perturbation_generator_config.walls_intersection_window = [[0.1, 0.9]]
            perturbation_generator_config.walls_length_range = [[var_val, var_val]]
        elif var_name.split("_")[1] == "number":
            perturbation_generator_config.n_walls = var_val
            walls_windows_pos = jnp.linspace(0, 1, 2 + var_val)
            walls_spacing = (walls_windows_pos[1] - walls_windows_pos[0]) * 3 / 4
            perturbation_generator_config.walls_intersection_window = [[t - walls_spacing, t + walls_spacing] for t in walls_windows_pos[1:-1]]
            perturbation_generator_config.walls_length_range = [[test_tasks["wall_length"][1], test_tasks["wall_length"][1]]] * var_val
    return perturbation_generator_config
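For instance, building the config for the noise task with std=0.005:
cfg = get_perturbation_generator_config("noise_std", 0.005)
print(cfg.perturbation_type, cfg.std)  # -> noise 0.005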
n_perturbations = 3
root_test_save_folder = "data/robustness_tests"
for test_task_var_name, test_task_var_range in test_tasks.items():
    for test_task_var_val in test_task_var_range:
        test_save_folder = f"{root_test_save_folder}/{test_task_var_name}_{test_task_var_val}"
        print(test_save_folder)
        if not os.path.exists(test_save_folder):
            perturbation_generator_config = get_perturbation_generator_config(test_task_var_name, test_task_var_val)
            perturbation_generator, perturbation_fn = create_perturbation_module(perturbation_generator_config)
            run_robustness_tests(jax_platform_name, seed, n_perturbations, test_save_folder,
                                 eval_system_outputs_library,
                                 eval_intervention_params_library, intervention_fn,
                                 perturbation_generator, perturbation_fn,
                                 eval_system_rollout, null_rollout_statistics_encoder,
                                 out_sanity_check=False, save_modules=False, save_logs=False)
To visualize how the different exploration strategies cover the transcriptional space, we convert the discovered trajectories into (pseudo-)energy landscapes, following the approach of the Trajectory-based energy landscapes of gene regulatory networks paper.
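Concretely, the code below bins all visited states into a 2D histogram $H$ over the space of observed node concentrations, and converts the resulting occupancy probabilities into a Boltzmann-like pseudo-energy:

$$U(x, y) = -\ln \frac{H(x, y)}{\sum_{x', y'} H(x', y')}$$

Empty bins are set to a count of 1 before taking the logarithm, so that unvisited regions receive the maximal finite energy rather than an infinite one.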
# Gather all valid (finite and non-negative) trajectory points discovered by IMGEP and random search
imgep_trajectories = imgep_experiment_history.system_output_library.ys[:, observed_node_ids, :]
is_valid_bool = ~(jnp.isnan(imgep_trajectories).any(-1).any(-1)) & (imgep_trajectories>=-1e-6).all(-1).all(-1)
imgep_points = jnp.concatenate(imgep_trajectories[is_valid_bool], axis=-1)
rs_trajectories = rs_experiment_history.system_output_library.ys[:, observed_node_ids, :]
is_valid_bool = ~(jnp.isnan(rs_trajectories).any(-1).any(-1)) & (rs_trajectories>=-1e-6).all(-1).all(-1)
rs_points = jnp.concatenate(rs_trajectories[is_valid_bool], axis=-1)
imgep_perturbed_points = []
for task_name in test_tasks.keys():
for task_var in test_tasks[task_name]:
test_experiment_history = DictTree.load(os.path.join(f"{root_test_save_folder}/{task_name}_{task_var}", "history.pickle"))
perturbed_trajectories = test_experiment_history.system_output_library.ys[:, :, observed_node_ids, :].reshape(len(eval_system_outputs_library.ys)*n_perturbations, 2, -1)
is_valid_bool = ~(jnp.isnan(perturbed_trajectories).any(-1).any(-1)) & (perturbed_trajectories>=-1e-6).all(-1).all(-1)
imgep_perturbed_points.append(jnp.concatenate(perturbed_trajectories[is_valid_bool], axis=-1))
imgep_perturbed_points = jnp.concatenate(imgep_perturbed_points, axis=-1)
all_points = jnp.concatenate([imgep_points, rs_points, imgep_perturbed_points], axis=-1)
ymin, ymax = all_points.min(-1), all_points.max(-1)
del all_points
results = {}
for k, points in zip(["(a) Random Search", "(b) IMGEP", "(c) IMGEP+perturbations"], [rs_points, imgep_points, imgep_perturbed_points]):
H, xedges, yedges = jnp.histogram2d(
x=points[0, :],
y=points[1, :],
bins=10,
range=jnp.stack([ymin, ymax]).transpose()
)
H = H.transpose()
# Convert bin counts into a probability distribution and pseudo-energy U = -log(P)
H = H.at[jnp.where(H == 0)].set(1)  # avoid log(0) for empty bins
U = -jnp.log(H / H.sum())
# Interpolate the energy landscape on a fine regular grid
bin_sizex = xedges[1] - xedges[0]
bin_sizey = yedges[1] - yedges[0]
x = xedges[:-1] + bin_sizex / 2
y = yedges[:-1] + bin_sizey / 2
interp = RegularGridInterpolator((x, y), U, method="cubic", bounds_error=False)
xi = jnp.linspace(ymin[0], ymax[0], 100)
yi = jnp.linspace(ymin[1], ymax[1], 100)
xi, yi = jnp.meshgrid(xi, yi)
zi = interp((xi, yi)).transpose()
# save results
results[k] = (xi, yi, zi)
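The `results` dictionary now holds one interpolated landscape per condition. Below is a minimal matplotlib sketch to visualize them side by side; the styling is illustrative, and the axis labels assume `observed_node_ids` corresponds to (ERK, RKIPP_RP) in that order:

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, len(results), figsize=(15, 4), sharex=True, sharey=True)
for ax, (method_name, (xi, yi, zi)) in zip(axes, results.items()):
    # Contour plot of the interpolated pseudo-energy landscape
    contour = ax.contourf(xi, yi, zi, levels=20, cmap="viridis")
    ax.set_title(method_name)
    ax.set_xlabel("ERK")
    ax.set_ylabel("RKIPP_RP")
    fig.colorbar(contour, ax=ax, label="pseudo-energy U")
plt.show()
```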
# Batch the system rollout over interventions with vmap
batched_system_rollout = vmap(system_rollout, in_axes=(0, None, 0, None, 0), out_axes=(0,None))
# Keep the most robust trajectories: samples whose mean sensitivity across
# all perturbation tasks falls below the 25th percentile
mean_sensitivities = jnp.nanmean(jnp.stack(list(data.values())), axis=0)
Q25 = jnp.percentile(mean_sensitivities, 25)
rob_sample_ids = jnp.where(mean_sensitivities < Q25)[0]
# "healthy" and "disease" regions polygons
regions_poly = {}
for label_idx, region_name in zip([1,2], ("healthy", "disease")):
cluster_point_ids = jnp.where(imgep_clusters_labels==label_idx)[0]
z_points = imgep_reached_endpoints[cluster_point_ids]
eps = 0.05
# Union of small disks around the endpoints, then a morphological closing (dilate then erode) to smooth the region
poly = unary_union([Point(point).buffer(eps) for point in z_points])
poly = poly.buffer(eps*5, join_style=1).buffer(-eps*5, join_style=1)
# Map the polygon from the normalized behavior space back to concentration coordinates
poly = affinity.scale(poly, xfact=(analytic_bc_space_high-analytic_bc_space_low)[0], yfact=(analytic_bc_space_high-analytic_bc_space_low)[1], origin=(0,0,0))
poly = affinity.translate(poly, xoff=analytic_bc_space_low[0], yoff=analytic_bc_space_low[1])
regions_poly[region_name] = poly
# Select pathways in "disease" region
disease_pathways_ids = []
for sample_idx in rob_sample_ids:
if regions_poly["disease"].contains(Point(eval_system_outputs_library.ys[sample_idx, observed_node_ids, -1])):
disease_pathways_ids.append(sample_idx)
disease_pathways_ids = jnp.array(disease_pathways_ids)
# Target endpoint: centroid of "healthy region"
target_endpoint = jnp.array(regions_poly["healthy"].centroid.coords)
print(f"Target: {target_endpoint}")
# Create an intervention function that (i) sets the network's initial state on the
# first interval and (ii) clamps the controlled node(s) on the subsequent intervals
class CustomInterventionFn(grn.PiecewiseIntervention):
    def apply(self, key, y, y_, w, w_, c, c_, interval_idx, intervention_params):
        y = lax.cond(interval_idx.sum() == 0, self.apply_init, self.apply_clamping, y, interval_idx, intervention_params)
        return y, w, c

    def apply_init(self, y, interval_idx, intervention_params):
        # First interval: reset the state to the provided initial condition
        return intervention_params.y0

    def apply_clamping(self, y, interval_idx, intervention_params):
        # Later intervals: clamp each controlled node to its value for this interval
        for y_idx, clamp_vals in intervention_params.clamping.items():
            y = y.at[y_idx].set(clamp_vals[interval_idx-1])
        return y
controlled_node = "MEKPP"
controlled_intervals = [[i, i+10] for i in range(10,100,10)]
custom_intervention_fn = CustomInterventionFn(grn.TimeToInterval(intervals=[[-deltaT/2., deltaT/2.]]+controlled_intervals))
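For reference, here is a minimal, hypothetical sketch of the parameters PyTree that `CustomInterventionFn` expects; the shapes mirror how the random-search loop below constructs them:

```python
# Hypothetical example of the params consumed by CustomInterventionFn (illustrative values)
example_intervention_params = DictTree()
# y0: full initial state, applied on the first interval [-deltaT/2, deltaT/2]
example_intervention_params.y0 = jnp.zeros(len(system_rollout.grn_step.y_indexes))
# clamping: maps a node index to the values it is clamped to on each controlled interval
mekpp_idx = system_rollout.grn_step.y_indexes[controlled_node]
example_intervention_params.clamping[mekpp_idx] = jnp.full((len(controlled_intervals),), 0.1)
```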
# Perturbation functions (similar to those in Part 1)
batched_noise_perturbation_generator = vmap(noise_perturbation_generator)
push_perturbation_generator_config.perturbed_intervals = [[int((T90+T10)/2.)-deltaT/2, int((T90+T10)/2.)+deltaT/2]]
push_perturbation_generator, _ = create_perturbation_module(push_perturbation_generator_config)
batched_push_perturbation_generator = vmap(push_perturbation_generator)
batched_wall_perturbation_generator = vmap(wall_perturbation_generator)
print(f"We search a stepwise intervention applied on {controlled_node} and on time intervals {controlled_intervals}, that will allow to get unstuck from 'disease' to 'healthy' region under a various perturbations")
# Define train loss: distance to the center of the "healthy" area (we want to achieve a state with low ERK and high RKIPP_RP concentrations)
def pathway_evaluate_worker_fn(key, batched_intervention_params, intervention_fn, batched_system_rollout, observed_node_ids, target_endpoint):
    batch_size = batched_intervention_params.y0.shape[0]
    # rollout the system without perturbation (used to generate the perturbation params)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs, _ = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, grn.NullIntervention(), None)
    # rollout the system with noise perturbation
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_noise_perturbation_params, _ = batched_noise_perturbation_generator(jnp.array(subkeys), batched_system_outputs)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs_noise, _ = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, noise_perturbation_fn, batched_noise_perturbation_params)
    # per-sample Euclidean distance between the reached endpoint and the target endpoint
    loss = jnp.sqrt(jnp.square(batched_system_outputs_noise.ys[:, jnp.array(observed_node_ids), -1] - target_endpoint).sum(-1))
    # rollout the system with push perturbation
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_push_perturbation_params, _ = batched_push_perturbation_generator(jnp.array(subkeys), batched_system_outputs)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs_push, _ = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, push_perturbation_fn, batched_push_perturbation_params)
    loss += jnp.sqrt(jnp.square(batched_system_outputs_push.ys[:, jnp.array(observed_node_ids), -1] - target_endpoint).sum(-1))
    # rollout the system with wall perturbation
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_wall_perturbation_params, _ = batched_wall_perturbation_generator(jnp.array(subkeys), batched_system_outputs)
    key, *subkeys = jrandom.split(key, num=batch_size+1)
    batched_system_outputs_wall, _ = batched_system_rollout(jnp.array(subkeys), intervention_fn, batched_intervention_params, wall_perturbation_fn, batched_wall_perturbation_params)
    loss += jnp.sqrt(jnp.square(batched_system_outputs_wall.ys[:, jnp.array(observed_node_ids), -1] - target_endpoint).sum(-1))
    return loss.mean(), [batched_system_outputs, batched_system_outputs_noise, batched_system_outputs_push, batched_system_outputs_wall], batched_wall_perturbation_params
pathway_evaluate_worker_fn = jit(jtu.Partial(pathway_evaluate_worker_fn, intervention_fn=custom_intervention_fn,
batched_system_rollout=batched_system_rollout, observed_node_ids=observed_node_ids, target_endpoint=target_endpoint))
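In equation form, writing $y_i^{(p)}(T)$ for the endpoint reached by sample $i$ under perturbation $p$ and $y^\star$ for the healthy-region centroid, the optimized loss is the batch average of the summed distances to the target:

$$\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \sum_{p \in \{\text{noise},\, \text{push},\, \text{wall}\}} \left\lVert y_i^{(p)}(T) - y^\star \right\rVert_2$$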
# Optimize the intervention with random search
n_trials = 20
batch_size = len(disease_pathways_ids)
best_loss = 1e7
for trial_idx in range(n_trials):
    # Sample random clamping values for the controlled node (MEKPP)
    batched_interventions_params = DictTree()
    # Set the last state of the disease pathways as the initial state (via the intervention)
    batched_interventions_params.y0 = eval_system_outputs_library.ys[disease_pathways_ids, :, -1]
    y_idx = system_rollout.grn_step.y_indexes[controlled_node]
    key, subkey = jrandom.split(key)
    batched_interventions_params.clamping[y_idx] = jnp.tile(jrandom.uniform(subkey, shape=(len(controlled_intervals),),
                                                                            minval=0.05*eval_system_outputs_library.ys[disease_pathways_ids, y_idx, -1].min(),
                                                                            maxval=20*eval_system_outputs_library.ys[disease_pathways_ids, y_idx, -1].max()),
                                                            (batch_size, 1))
    # Evaluate the candidate intervention under the noise, push and wall perturbations
    key, subkey = jrandom.split(key)
    trial_loss, batched_outputs, wall_params = pathway_evaluate_worker_fn(subkey, batched_interventions_params)
    print(f"Trial {trial_idx}: loss = {trial_loss}")
    if trial_loss < best_loss:
        best_loss = trial_loss
        best_params = batched_interventions_params
        best_outputs = batched_outputs
        display_wall_params = wall_params
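Once the search has completed, we can quickly inspect the best intervention found (a minimal sketch; the printout is illustrative):

```python
# Inspect the best intervention found by the random search
print(f"Best loss: {best_loss}")
print(f"Best clamping values for {controlled_node}: {best_params.clamping[y_idx][0]}")
```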
👉 Our behavioral catalog allows us to identify robustly reachable goal states, including undesired "disease" states (which were not discovered by random search), and to design interventions that robustly reset the network from these disease states toward the "healthy" region.