| Authors | Affiliation | Published |
|---|---|---|
| Mayalen Etcheverry | INRIA, Flowers team, Poietis | September, 2023 |
| Clément Moulin-Frier | INRIA, Flowers team | |
| Pierre-Yves Oudeyer | INRIA, Flowers team | |
| Michael Levin | The Levin Lab, Tufts University | |
TL;DR
This second tutorial accompanies our paper Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks, and more specifically its last section, "Reuse of the framework as an alternative strategy to gene circuit engineering".
When simulating synthetic gene regulatory networks, we typically assume one family of ODEs. Here we use the transcriptional gene circuit model, with a simple model step defined as:
$\frac{d{y}_i}{dt}=\phi \left({\sum}_j{W}_{ij}{y}_j+{B}_i\right)-{k}_i{y}_i$
Here, we use $k_i=1$, $W_{ij}\in[-30,30]$, $B_i\in[-10,10]$; with these parameters, the species concentrations are constrained to $y_i\in[0,1]$.
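Concretely, the SimpleModelStep module below integrates this ODE with an explicit Euler scheme ($\phi$ being the sigmoid and $k_i=1$):

$y_i(t+\Delta t) = y_i(t) + \Delta t\left(\phi\Big(\sum_j W_{ij}\, y_j(t) + B_i\Big) - y_i(t)\right)$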
import equinox as eqx
import jax.numpy as jnp
from jax import jit

@jit
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

class SimpleModelStep(eqx.Module):
    def __init__(self, **kwargs):
        super().__init__()

    @jit
    def __call__(self, y, w, c, t, deltaT):
        n = len(y)
        # unpack the kinetic parameters: W is the (n, n) interaction matrix, B the bias vector
        W = c[:n * n].reshape((n, n))
        B = c[n * n:(n + 1) * n]
        # explicit Euler step of dy/dt = sigmoid(W @ y + B) - y (i.e., k_i = 1)
        y_new = y + deltaT * (sigmoid(W @ y + B) - y)
        t_new = t + deltaT
        w_new = w
        return y_new, w_new, c, t_new
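As a quick sanity check (illustrative, not part of the original pipeline), we can apply a single Euler step with random parameters and verify that the state stays within [0, 1]; this cell also initializes the PRNG key used in the following cells:

import jax.random as jrandom

key = jrandom.PRNGKey(0)
n = 3
key, k1, k2 = jrandom.split(key, 3)
# sample W entries in [-30, 30] and B entries in [-10, 10], flattened into a single vector c
c_demo = jnp.concatenate([jrandom.uniform(k1, (n * n,), minval=-30., maxval=30.),
                          jrandom.uniform(k2, (n,), minval=-10., maxval=10.)])
y_demo = jnp.full((n,), 0.5)
y_next, _, _, t_next = SimpleModelStep()(y_demo, jnp.array([]), c_demo, 0.0, 0.01)
assert ((y_next >= 0.0) & (y_next <= 1.0)).all()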
Now that we have defined the new ModelStep function, AutoDiscJax allows us to simulate system rollouts (and apply different kinds of interventions on them) in the same manner as we did for biological networks in the first tutorial.
Let's instantiate the system rollout module.
n = 3  # number of nodes
deltaT = 0.01
n_secs = 100
n_steps = int(n_secs / deltaT)  # 10000 Euler steps
c = jnp.empty(((n + 1) * n,))  # kinetic parameters: n*n weights W followed by n biases B
c_low = jnp.array([-30.] * n**2 + [-10.] * n)
c_high = jnp.array([30.] * n**2 + [10.] * n)
grn_step = SimpleModelStep()
y0 = jnp.empty(shape=(n,))
y0_low = 0.
y0_high = 1.
w0 = jnp.array([])
system_rollout = grn.GRNRollout(n_steps=n_steps, y0=y0, w0=w0, c=c, t0=0.0, deltaT=deltaT, grn_step=grn_step)
Let's now use interventions to (randomly) set the GRN's initial state (y0) and kinetic parameters (c).
# Create an intervention generator and an intervention_fn module to set the initial state and the kinetic parameters to random values
random_intervention_generator_config = Dict()
random_intervention_generator_config.intervention_type = "set_uniform"
random_intervention_generator_config.controlled_intervals = [[0, deltaT / 2.0]]

intervention_params_tree = DictTree()
intervention_params_low = DictTree()
intervention_params_high = DictTree()
for y_idx in range(len(y0)):
    intervention_params_tree.y[y_idx] = "placeholder"
    intervention_params_low.y[y_idx] = y0_low
    intervention_params_high.y[y_idx] = y0_high
for c_idx in range(len(c)):
    intervention_params_tree.c[c_idx] = "placeholder"
    intervention_params_low.c[c_idx] = c_low[c_idx]
    intervention_params_high.c[c_idx] = c_high[c_idx]

random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),), intervention_params_tree)
random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)
random_intervention_generator_config.low = intervention_params_low
random_intervention_generator_config.high = intervention_params_high
random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)
# example: generate a random set of intervention parameters between low and high
key, subkey = jrandom.split(key)
intervention_params, log_data = random_intervention_generator(subkey)
# Run the system with the sample intervention
key, subkey = jrandom.split(key)
random_system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn, intervention_params=intervention_params)
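We can quickly inspect the resulting rollout: the outputs store the species trajectories in ys (as used by the goal encoder below), so its shape should be (n, n_steps):

print(random_system_outputs.ys.shape)  # expected: (3, 10000)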
For the IMGEP goal space, we use the image space of the discrete Fourier transform of the 1D signal of the first node, $y_{n=0}$.
observed_node_ids = [0]
goal_embedding_encoder_config = Dict()
goal_embedding_encoder_config.encoder_type = "filter"
goal_embedding_tree = "placeholder"
goal_embedding_encoder_config.out_treedef = jtu.tree_structure(goal_embedding_tree)
# the rfft of the last n_steps//2 time steps yields n_steps//2//2 + 1 (complex) coefficients per observed node
goal_embedding_encoder_config.out_shape = jtu.tree_map(lambda _: (len(observed_node_ids) * (system_rollout.n_steps // 2 // 2 + 1),), goal_embedding_tree)
goal_embedding_encoder_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, goal_embedding_tree)
# take the real FFT of the second half of the rollout (skipping the initial transient)
goal_embedding_encoder_config.filter_fn = jtu.Partial(lambda system_outputs: jnp.fft.rfft(system_outputs.ys[observed_node_ids, -system_rollout.n_steps // 2:]).flatten())
goal_embedding_encoder = create_goal_embedding_encoder_module(goal_embedding_encoder_config)
# example: encode system outputs
key, subkey = jrandom.split(key)
reached_goal_embedding, log_data = goal_embedding_encoder(subkey, random_system_outputs)
print(reached_goal_embedding.shape)
(2501,)
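(The dimension follows from the encoder above: the real FFT of the last n_steps//2 = 5000 time steps of a real-valued signal yields 5000//2 + 1 = 2501 complex Fourier coefficients.)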
Distance in the goal space is measured as the mean magnitude of the difference between the reached and target spectra.
goal_achievement_loss_config = Dict()
goal_achievement_loss_config.loss_type = "custom"
goal_achievement_loss_config.loss_f = jtu.Partial(lambda reached_goal, target_goal: abs(reached_goal - target_goal).mean())
goal_achievement_loss = create_goal_achievement_loss_module(goal_achievement_loss_config)
# example (y_descriptors is assumed to be a target spectrum computed earlier in the notebook)
target_goal_embedding = y_descriptors
key, subkey = jrandom.split(key)
gc_loss, log_data = goal_achievement_loss(subkey, reached_goal_embedding, target_goal_embedding)
print(gc_loss)
4.0033674
For the goal generator, the goal-conditioned intervention selector, and the optimizer, we reuse the same simple variants as the ones used in the first tutorial.
goal_generator_config = DictTree()
goal_generator_config.out_treedef = goal_embedding_encoder_config.out_treedef
goal_generator_config.out_shape = goal_embedding_encoder_config.out_shape
goal_generator_config.out_dtype = goal_embedding_encoder_config.out_dtype
goal_generator_config.low = None
goal_generator_config.high = None
goal_generator_config.generator_type = "hypercube"
goal_generator_config.hypercube_scaling = 1.3
goal_generator = create_goal_generator_module(goal_generator_config)
# example
key, subkey = jrandom.split(key)
next_target_goal_embedding, log_data = goal_generator(subkey, target_goal_embedding[jnp.newaxis], jnp.stack([reached_goal_embedding,target_goal_embedding]))
print(next_target_goal_embedding.shape)
(2501,)
gc_intervention_selector_config = Dict()
gc_intervention_selector_config.selector_type = "nearest_neighbor"
gc_intervention_selector_config.loss_f = goal_achievement_loss.loss_f
gc_intervention_selector_config.k = 1
gc_intervention_selector = create_gc_intervention_selector_module(gc_intervention_selector_config)
# example
key, subkey = jrandom.split(key)
source_interventions_idx, log_data = gc_intervention_selector(subkey, next_target_goal_embedding, jnp.stack([reached_goal_embedding, target_goal_embedding]))
print(source_interventions_idx)
1
gc_intervention_optimizer_config = Dict()
gc_intervention_optimizer_config.out_treedef = random_intervention_generator.out_treedef
gc_intervention_optimizer_config.out_shape = random_intervention_generator.out_shape
gc_intervention_optimizer_config.out_dtype = random_intervention_generator.out_dtype
gc_intervention_optimizer_config.low = random_intervention_generator.low
gc_intervention_optimizer_config.high = random_intervention_generator.high
gc_intervention_optimizer_config.optimizer_type = "EA"
gc_intervention_optimizer_config.n_optim_steps = 1
gc_intervention_optimizer_config.n_workers = 1
gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: 0.1 * (high - low),
                                                               gc_intervention_optimizer_config.low,
                                                               gc_intervention_optimizer_config.high)
gc_intervention_optimizer = create_gc_intervention_optimizer_module(gc_intervention_optimizer_config)
null_perturbation_generator, null_perturbation_fn = create_perturbation_module(Dict(perturbation_type="null"))
null_rollout_statistics_encoder = create_rollout_statistics_encoder_module(Dict(statistics_type="null"))
partial_gc_intervention_optimizer = jtu.Partial(gc_intervention_optimizer,
                                                perturbation_generator=null_perturbation_generator,
                                                perturbation_fn=null_perturbation_fn,
                                                intervention_fn=intervention_fn,
                                                system_rollout=system_rollout,
                                                goal_embedding_encoder=goal_embedding_encoder,
                                                goal_achievement_loss=goal_achievement_loss,
                                                rollout_statistics_encoder=null_rollout_statistics_encoder)
# example
key, subkey = jrandom.split(key)
optimized_intervention_params, log_data = partial_gc_intervention_optimizer(subkey, intervention_params, next_target_goal_embedding, reached_goal_embedding)
print(jtu.tree_map(lambda node: node.shape, optimized_intervention_params))
{'c': {0: (1,), 1: (1,), 2: (1,), 3: (1,), 4: (1,), 5: (1,), 6: (1,), 7: (1,), 8: (1,), 9: (1,), 10: (1,), 11: (1,)}, 'y': {0: (1,), 1: (1,), 2: (1,)}}
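The returned tree mirrors the intervention parameters: $n^2+n=12$ leaves for c (the entries of W and B) and $n=3$ leaves for y, each holding one value per controlled time interval.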
Now that we have defined the IMGEP internal models, we can run the IMGEP experimental pipeline. As in the previous tutorial, we compare it with a random exploration strategy given the same experimental budget. Here we run a total of N=5000 experiments: 10 random batches followed by 40 IMGEP batches, with a batch size of 100.
jax_platform_name = "cpu"
seed = 0
# Run IMGEP
n_random_batches = 10
n_imgep_batches = 40
batch_size = 100
imgep_experiment_data_save_folder = "data/periodic_imgep_data"
if not os.path.exists(os.path.join(imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         imgep_experiment_data_save_folder,
                         random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)
# Run Random Search
rs_experiment_data_save_folder = "data/periodic_rs_data"
if not os.path.exists(os.path.join(rs_experiment_data_save_folder, "history.pickle")):
    run_rs_experiment(jax_platform_name, seed, n_random_batches + n_imgep_batches, batch_size,
                      rs_experiment_data_save_folder,
                      random_intervention_generator, intervention_fn,
                      null_perturbation_generator, null_perturbation_fn,
                      system_rollout, null_rollout_statistics_encoder,
                      out_sanity_check=False, save_modules=False, save_logs=False)
imgep_experiment_history = DictTree.load(os.path.join(imgep_experiment_data_save_folder, "history.pickle"))
imgep_reached_goals_embeddings = imgep_experiment_history.reached_goal_embedding_library
print(imgep_reached_goals_embeddings.shape)
rs_experiment_history = DictTree.load(os.path.join(rs_experiment_data_save_folder, "history.pickle"))
# the random-search history stores raw system outputs, so we (re-)encode them into the goal space
key, *subkeys = jrandom.split(key, num=len(imgep_reached_goals_embeddings) + 1)
rs_reached_goals_embeddings, _ = vmap(goal_embedding_encoder)(jnp.array(subkeys), rs_experiment_history.system_output_library)
print(rs_reached_goals_embeddings.shape)
(5000, 2501)
(5000, 2501)
rs_is_periodic_bool, rs_offset_vals, rs_ampl_vals, rs_freq_vals = is_periodic(rs_experiment_history.system_output_library.ys[:,0,:], jnp.r_[-system_rollout.n_steps//2:0], system_rollout.deltaT, 40)
rs_is_periodic_ids = jnp.where(rs_is_periodic_bool)[0]
print(f"Random Search has discovered {len(rs_is_periodic_ids)} oscillator circuits out of N={len(rs_is_periodic_bool)} trials")
imgep_is_periodic_bool, imgep_offset_vals, imgep_ampl_vals, imgep_freq_vals = is_periodic(imgep_experiment_history.system_output_library.ys[:,0,:], jnp.r_[-system_rollout.n_steps//2:0], system_rollout.deltaT, 40)
imgep_is_periodic_ids = jnp.where(imgep_is_periodic_bool)[0]
print(f"Curiosity Search has discovered {len(imgep_is_periodic_ids)} oscillator circuits out of N={len(imgep_is_periodic_bool)} trials")
Random Search has discovered 42 oscillator circuits out of N=5000 trials
Curiosity Search has discovered 1167 oscillator circuits out of N=5000 trials
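In other words, under the same experimental budget, curiosity search discovers roughly 28 times more oscillator circuits than random search.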
Here, the analytic BC space is the space of (amplitude $A$, main frequency $\omega$, offset $b$) tuples of the discovered oscillators, where $(A,\omega,b)$ are estimated by AutoDiscJax's is_periodic util function.
Diversity is measured with the QD-score, a binning-based metric where the BC space is discretized into a collection of bins and diversity is quantified as the number of bins filled over the course of exploration (computed after the histogram code below).
We opt for a regular binning where each dimension of the BC space is discretized into equally sized bins, using 20 bins per dimension.
We do not use the threshold-coverage metric from tutorial 1, as it is difficult to compute in n-dimensional spaces where $n\ge3$.
imgep_reached_goals_embeddings = jnp.stack([imgep_offset_vals.at[~imgep_is_periodic_bool].set(0.0),
                                            imgep_ampl_vals.at[~imgep_is_periodic_bool].set(0.0),
                                            imgep_freq_vals.at[~imgep_is_periodic_bool].set(0.0)], -1)
rs_reached_goals_embeddings = jnp.stack([rs_offset_vals.at[~rs_is_periodic_bool].set(0.0),
                                         rs_ampl_vals.at[~rs_is_periodic_bool].set(0.0),
                                         rs_freq_vals.at[~rs_is_periodic_bool].set(0.0)], -1)
analytic_bc_space_low = jnp.minimum(jnp.nanmin(imgep_reached_goals_embeddings, 0), jnp.nanmin(rs_reached_goals_embeddings, 0))
analytic_bc_space_high = jnp.maximum(jnp.nanmax(imgep_reached_goals_embeddings, 0), jnp.nanmax(rs_reached_goals_embeddings, 0))
analytic_bc_space_extent = jnp.stack([analytic_bc_space_low, analytic_bc_space_high]).transpose()  # shape (3, 2): (min, max) per BC dimension
def calc_analytic_bc_coverage_histograms(reached_goals_embeddings, analytic_bc_space_extent, n_bins=20, every_n_steps=1):
    def f(carry, goal_embedding):
        Hf = carry
        # histogram of the current reached goal: one count in one bin (or none if outside the extent)
        cur_Hf, _ = jnp.histogramdd(goal_embedding[jnp.newaxis], bins=n_bins, range=analytic_bc_space_extent)
        # accumulate counts; cast to int32 so the scan carry keeps a stable dtype
        Hf = Hf + cur_Hf.transpose().astype(jnp.int32)
        return Hf, Hf
    final_coverage_histogram, coverage_histograms = lax.scan(f, jnp.zeros((n_bins, n_bins, n_bins), dtype=jnp.int32),
                                                             reached_goals_embeddings[::every_n_steps])
    return coverage_histograms
imgep_coverage_histograms = calc_analytic_bc_coverage_histograms(imgep_reached_goals_embeddings, analytic_bc_space_extent, n_bins=20)
rs_coverage_histograms = calc_analytic_bc_coverage_histograms(rs_reached_goals_embeddings, analytic_bc_space_extent, n_bins=20)
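From these cumulative histograms, we can read off the QD-score at any point of exploration as the number of non-empty bins. A minimal sketch, assuming the histograms are stacked along the first (experiment) axis as returned by lax.scan above:

imgep_qd_scores = (imgep_coverage_histograms > 0).sum(axis=(1, 2, 3))
rs_qd_scores = (rs_coverage_histograms > 0).sum(axis=(1, 2, 3))
print(imgep_qd_scores[-1], rs_qd_scores[-1])  # final QD-scores after N=5000 experiments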