
AI-driven Automated Discovery Tools for Synthetic Circuit Engineering

Authors and affiliations:
Mayalen Etcheverry (INRIA, Flowers team; Poietis)
Clément Moulin-Frier (INRIA, Flowers team)
Pierre-Yves Oudeyer (INRIA, Flowers team)
Michael Levin (The Levin Lab, Tufts University)

Published: September 2023

Introduction

TL;DR

This second tutorial accompanies our paper "Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks", in particular its last section, "Reuse of the framework as an alternative strategy to gene circuit engineering".


ModelStep function

When simulating synthetic gene regulatory networks, we typically assume a single family of ODEs. Here we use the transcriptional gene circuit model, whose dynamics are given by:

$\frac{dy_i}{dt} = \phi\left(\sum_j W_{ij} y_j + B_i\right) - k_i y_i$

Here, $\phi(x) = 1/(1+e^{-x})$ is the logistic sigmoid, and we use $k_i = 1$, $W_{ij} \in [-30, 30]$ and $B_i \in [-10, 10]$. With these parameters, species concentrations remain in $y \in [0, 1]$ (for initial states in $[0, 1]$). The model step below integrates this ODE with an explicit Euler scheme of time step deltaT.

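import jax.numpy as jnp
import equinox as eqx
from jax import jit

@jit
def sigmoid(x):
    # logistic activation phi(x) = 1 / (1 + exp(-x))
    return 1 / (1 + jnp.exp(-x))

class SimpleModelStep(eqx.Module):
    def __init__(self, **kwargs):
        super().__init__()

    @jit
    def __call__(self, y, w, c, t, deltaT):
        # One explicit Euler step of dy/dt = phi(W y + B) - y  (k_i = 1)
        n = len(y)
        W = c[:n * n].reshape((n, n))   # first n*n entries of c: weight matrix W
        B = c[n * n:(n + 1) * n]        # next n entries of c: biases B
        y_new = y + deltaT * (sigmoid(W @ y + B) - y)
        t_new = t + deltaT
        w_new = w                       # no auxiliary state in this model

        return y_new, w_new, c, t_new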
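As a quick sanity check of the [0, 1] bound, the following standalone sketch (euler_rollout is a hypothetical helper written in plain JAX, not an AutoDiscJax module) integrates the same ODE with the same explicit Euler step, using random parameters drawn from the ranges above. With deltaT ≤ 1, each update y + deltaT·(φ(·) − y) = (1 − deltaT)·y + deltaT·φ(·) is a convex combination of the current state and a sigmoid output in (0, 1), so the state cannot leave [0, 1].

```python
import jax
import jax.numpy as jnp


def euler_rollout(y0, W, B, deltaT=0.01, n_steps=10_000):
    """Integrate dy/dt = sigmoid(W y + B) - y (k_i = 1) with explicit Euler steps."""
    def step(y, _):
        y_next = y + deltaT * (jax.nn.sigmoid(W @ y + B) - y)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return ys  # shape (n_steps, n)


n = 3
k_y, k_w, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
y0 = jax.random.uniform(k_y, (n,), minval=0.0, maxval=1.0)
W = jax.random.uniform(k_w, (n, n), minval=-30.0, maxval=30.0)
B = jax.random.uniform(k_b, (n,), minval=-10.0, maxval=10.0)

ys = euler_rollout(y0, W, B)
print(float(ys.min()), float(ys.max()))  # both within [0, 1]
```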

Part 1: Curiosity-driven search for the discovery of diverse oscillator circuits

Experiment Pipeline and Modules

System Rollout module

Now that we have defined the new ModelStep function, AutoDiscJax lets us simulate system rollouts (and apply different kinds of interventions to them) in the same way as we did for biological networks in the first tutorial.

Let's instantiate the system rollout module.

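n = 3  # number of nodes
deltaT = 0.01  # Euler integration time step
n_secs = 100  # simulated time
n_steps = int(n_secs/deltaT)  # 10,000 steps

# kinetic parameters c = [W (flattened, n*n entries), B (n entries)] and their bounds
c = jnp.empty(((n + 1) * n, ))
c_low = jnp.array([-30.]*n**2 + [-10.]*n)
c_high = jnp.array([30.]*n**2 + [10.]*n)
grn_step = SimpleModelStep()

# initial state y0 and its bounds
y0 = jnp.empty(shape=(n,))
y0_low = 0.
y0_high = 1.

w0 = jnp.array([])  # no auxiliary state

system_rollout = grn.GRNRollout(n_steps=n_steps, y0=y0, w0=w0, c=c, t0=0.0, deltaT=deltaT, grn_step=grn_step)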
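The kinetic parameter vector c is a flat array whose first n·n entries hold W (row-major) and whose last n entries hold B, matching the slicing done inside SimpleModelStep. The sketch below (pack_params and unpack_params are hypothetical helpers, not AutoDiscJax functions) makes that layout and the bound vectors c_low and c_high explicit.

```python
import jax.numpy as jnp


def pack_params(W, B):
    """Hypothetical helper: flatten W row-major, then append B (layout expected by SimpleModelStep)."""
    return jnp.concatenate([W.reshape(-1), B])


def unpack_params(c, n):
    """Mirror of the slicing done inside SimpleModelStep.__call__."""
    W = c[:n * n].reshape((n, n))
    B = c[n * n:(n + 1) * n]
    return W, B


n = 3
W = jnp.arange(n * n, dtype=jnp.float32).reshape((n, n))
B = jnp.array([-1.0, 0.0, 1.0])
c = pack_params(W, B)
W_back, B_back = unpack_params(c, n)

# bounds used for random interventions: W entries in [-30, 30], B entries in [-10, 10]
c_low = jnp.array([-30.0] * n**2 + [-10.0] * n)
c_high = jnp.array([30.0] * n**2 + [10.0] * n)
n_steps = int(100 / 0.01)
print(c.shape, c_low.shape, n_steps)  # (12,) (12,) 10000
```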

Random Intervention Generator

Let's now use interventions to (randomly) set the GRN's initial state (y0) and kinetic parameters (c). The intervention is restricted to the time interval [0, deltaT/2], i.e. to the first time step of the rollout.

# Create an intervention generator and an intervention_fn module that set the initial state and the kinetic parameters to random values
random_intervention_generator_config = Dict()
random_intervention_generator_config.intervention_type = "set_uniform"
random_intervention_generator_config.controlled_intervals = [[0, deltaT/2.0]]

intervention_params_tree = DictTree()
intervention_params_low = DictTree()
intervention_params_high = DictTree()
for y_idx in range(len(y0)):
    intervention_params_tree.y[y_idx] = "placeholder"
    intervention_params_low.y[y_idx] = y0_low
    intervention_params_high.y[y_idx] = y0_high
for c_idx in range(len(c)):
    intervention_params_tree.c[c_idx] = "placeholder"
    intervention_params_low.c[c_idx] = c_low[c_idx]
    intervention_params_high.c[c_idx] = c_high[c_idx]

random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),), intervention_params_tree)
random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)
random_intervention_generator_config.low = intervention_params_low
random_intervention_generator_config.high = intervention_params_high

random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)
# example: generate a random set of intervention parameters between low and high
key, subkey = jrandom.split(key)
intervention_params, log_data = random_intervention_generator(subkey)

# Run the system with the sample intervention
key, subkey = jrandom.split(key)
random_system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn, intervention_params=intervention_params)
Figure 1: Simulation results of the model for random kinetic parameters (c) and initial gene expression levels (y0).

Goal Embedding Encoder

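For the IMGEP goal space, we use the discrete Fourier transform (jnp.fft.rfft) of the 1D signal of node 0, computed over the second half of the rollout. The resulting goal embedding has 5000 // 2 + 1 = 2501 Fourier coefficients.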

Figure 2: Illustration of Fourier descriptors, which are here used as goal representation by the IMGEP.
observed_node_ids = [0]

goal_embedding_encoder_config = Dict()
goal_embedding_encoder_config.encoder_type = "filter"
goal_embedding_tree = "placeholder"
goal_embedding_encoder_config.out_treedef = jtu.tree_structure(goal_embedding_tree)
goal_embedding_encoder_config.out_shape = jtu.tree_map(lambda _: (len(observed_node_ids)*(system_rollout.n_steps//2//2+1), ), goal_embedding_tree)
goal_embedding_encoder_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, goal_embedding_tree)
goal_embedding_encoder_config.filter_fn = jtu.Partial(lambda system_outputs: jnp.fft.rfft(system_outputs.ys[observed_node_ids, -system_rollout.n_steps//2:]).flatten())


goal_embedding_encoder = create_goal_embedding_encoder_module(goal_embedding_encoder_config)
# example: encode system outputs
key, subkey = jrandom.split(key)
reached_goal_embedding, log_data = goal_embedding_encoder(subkey, random_system_outputs)
print(reached_goal_embedding.shape)
(2501,)
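To make the goal embedding concrete, this sketch applies the same transform as filter_fn to a toy signal (a hypothetical 0.2 Hz cosine standing in for node 0). jnp.fft.rfft over the last n_steps // 2 = 5000 samples returns 2501 complex coefficients, and bin k corresponds to frequency k / (5000 · deltaT) = k / 50 Hz.

```python
import jax.numpy as jnp

deltaT, n_steps = 0.01, 10_000
ts = jnp.arange(n_steps) * deltaT
y = 0.5 + 0.3 * jnp.cos(2 * jnp.pi * 0.2 * ts)   # toy oscillation of node 0 at 0.2 Hz

window = y[-n_steps // 2:]                       # second half of the rollout (5000 samples)
descriptor = jnp.fft.rfft(window)                # complex Fourier coefficients
print(descriptor.shape)                          # (2501,)

freqs = jnp.fft.rfftfreq(window.size, d=deltaT)  # frequency (Hz) of each bin
peak = jnp.argmax(jnp.abs(descriptor[1:])) + 1   # skip the DC bin, which encodes the offset
print(int(peak), float(freqs[peak]))             # bin 10, about 0.2 Hz
```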

Goal-conditioned Achievement Loss

The distance in goal space is the mean absolute difference between the reached and target Fourier descriptors, averaged over frequency bins.

goal_achievement_loss_config = Dict()
goal_achievement_loss_config.loss_type = "custom" 
goal_achievement_loss_config.loss_f = jtu.Partial(lambda reached_goal, target_goal: abs(reached_goal - target_goal).mean())
goal_achievement_loss = create_goal_achievement_loss_module(goal_achievement_loss_config)
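# example
# (y_descriptors is not defined in this excerpt; any goal embedding of shape (2501,) can serve as target here)
target_goal_embedding = y_descriptors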
key, subkey = jrandom.split(key)
gc_loss, log_data = goal_achievement_loss(subkey, reached_goal_embedding, target_goal_embedding)
print(gc_loss)
4.0033674

Goal Generator

For the goal generator, the goal-conditioned intervention selector and the optimizer, we reuse the same simple variants as in the first tutorial.

goal_generator_config = DictTree()
goal_generator_config.out_treedef = goal_embedding_encoder_config.out_treedef
goal_generator_config.out_shape = goal_embedding_encoder_config.out_shape
goal_generator_config.out_dtype = goal_embedding_encoder_config.out_dtype
goal_generator_config.low = None
goal_generator_config.high = None
goal_generator_config.generator_type = "hypercube"
goal_generator_config.hypercube_scaling = 1.3

goal_generator = create_goal_generator_module(goal_generator_config)
# example
key, subkey = jrandom.split(key)
next_target_goal_embedding, log_data = goal_generator(subkey, target_goal_embedding[jnp.newaxis], jnp.stack([reached_goal_embedding,target_goal_embedding]))
print(next_target_goal_embedding.shape)
(2501,)
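The "hypercube" goal generator is configured here only through hypercube_scaling. A common reading in IMGEP work, which we assume here without claiming it is exactly what AutoDiscJax implements, is to sample target goals uniformly inside the bounding box of the goals reached so far, enlarged by that factor around its center, so that some targets fall slightly outside the explored region. A minimal sketch for real-valued goals (sample_hypercube_goal is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def sample_hypercube_goal(key, reached_goals, scaling=1.3):
    """Assumed semantics: uniform sample in the bounding box of the reached goals,
    enlarged by `scaling` around its center."""
    low = reached_goals.min(axis=0)
    high = reached_goals.max(axis=0)
    center, half_width = (low + high) / 2, (high - low) / 2
    return jax.random.uniform(key, shape=center.shape,
                              minval=center - scaling * half_width,
                              maxval=center + scaling * half_width)


reached = jnp.array([[0.0, 1.0],
                     [2.0, 3.0],
                     [1.0, 5.0]])
goal = sample_hypercube_goal(jax.random.PRNGKey(0), reached)
# box per dimension: [0, 2] -> [-0.3, 2.3] and [1, 5] -> [0.4, 5.6]
print(goal)
```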

Goal-conditioned Intervention Selector

gc_intervention_selector_config = Dict()
gc_intervention_selector_config.selector_type = "nearest_neighbor"
gc_intervention_selector_config.loss_f = goal_achievement_loss.loss_f
gc_intervention_selector_config.k = 1

gc_intervention_selector = create_gc_intervention_selector_module(gc_intervention_selector_config)
# example
key, subkey = jrandom.split(key)
source_interventions_idx, log_data = gc_intervention_selector(subkey, next_target_goal_embedding, jnp.stack([reached_goal_embedding, target_goal_embedding]))
print(source_interventions_idx)
1
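With k = 1 and the goal-achievement loss as distance, the selector picks the past experiment whose reached goal is closest to the new target; in the IMGEP loop, its intervention parameters are then the starting point for the optimizer. A minimal sketch of that selection rule (nearest_neighbor_index is a hypothetical helper, not the AutoDiscJax module):

```python
import jax
import jax.numpy as jnp


def l1_goal_loss(reached_goal, target_goal):
    # same form as goal_achievement_loss_config.loss_f: mean absolute difference
    return jnp.abs(reached_goal - target_goal).mean()


def nearest_neighbor_index(target_goal, goal_library):
    """k = 1 nearest neighbour: index of the library goal with the lowest loss to the target."""
    losses = jax.vmap(l1_goal_loss, in_axes=(0, None))(goal_library, target_goal)
    return int(jnp.argmin(losses))


library = jnp.array([[0.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0]])
print(nearest_neighbor_index(jnp.array([0.9, 1.2, 0.8]), library))  # 1
```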

Goal-conditioned Intervention Optimizer

gc_intervention_optimizer_config = Dict()
gc_intervention_optimizer_config.out_treedef = random_intervention_generator.out_treedef
gc_intervention_optimizer_config.out_shape = random_intervention_generator.out_shape
gc_intervention_optimizer_config.out_dtype = random_intervention_generator.out_dtype
gc_intervention_optimizer_config.low = random_intervention_generator.low
gc_intervention_optimizer_config.high = random_intervention_generator.high
gc_intervention_optimizer_config.optimizer_type = "EA"
gc_intervention_optimizer_config.n_optim_steps = 1
gc_intervention_optimizer_config.n_workers = 1
gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: 0.1 * (high - low), 
                                                               gc_intervention_optimizer_config.low, gc_intervention_optimizer_config.high)

gc_intervention_optimizer = create_gc_intervention_optimizer_module(gc_intervention_optimizer_config)
null_perturbation_generator, null_perturbation_fn = create_perturbation_module(Dict(perturbation_type="null"))
null_rollout_statistics_encoder = create_rollout_statistics_encoder_module(Dict(statistics_type="null"))
partial_gc_intervention_optimizer = jtu.Partial(gc_intervention_optimizer,
                                        perturbation_generator=null_perturbation_generator, perturbation_fn=null_perturbation_fn,
                                        intervention_fn=intervention_fn, system_rollout=system_rollout,
                                        goal_embedding_encoder=goal_embedding_encoder, goal_achievement_loss=goal_achievement_loss,
                                        rollout_statistics_encoder=null_rollout_statistics_encoder
                                        )
# example
key, subkey = jrandom.split(key)
optimized_intervention_params, log_data = partial_gc_intervention_optimizer(subkey, intervention_params, next_target_goal_embedding, reached_goal_embedding)
print(jtu.tree_map(lambda node: node.shape, optimized_intervention_params))
{'c': {0: (1,), 1: (1,), 2: (1,), 3: (1,), 4: (1,), 5: (1,), 6: (1,), 7: (1,), 8: (1,), 9: (1,), 10: (1,), 11: (1,)}, 'y': {0: (1,), 1: (1,), 2: (1,)}}
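The "EA" optimizer above is configured with a Gaussian mutation noise of standard deviation 0.1 · (high − low) per parameter. We do not reproduce AutoDiscJax's implementation; the sketch below is a generic mutate, evaluate and select step under that assumption (ea_step and toy_loss are hypothetical names). It clips mutated parameters to [low, high] and keeps the parent when no child improves on it.

```python
import jax
import jax.numpy as jnp


def ea_step(key, params, low, high, loss_fn, n_workers=1, noise_scale=0.1):
    """One mutate-evaluate-select step: Gaussian mutations of std noise_scale * (high - low),
    clipped to the bounds; returns the best of the parent and its children."""
    std = noise_scale * (high - low)
    noise = std * jax.random.normal(key, (n_workers,) + params.shape)
    children = jnp.clip(params + noise, low, high)
    candidates = jnp.concatenate([params[None], children])  # parent first
    losses = jax.vmap(loss_fn)(candidates)
    return candidates[jnp.argmin(losses)]


def toy_loss(p):
    # hypothetical objective standing in for the goal-achievement loss of a rollout
    return jnp.sum((p - 5.0) ** 2)


low, high = jnp.full(4, -10.0), jnp.full(4, 10.0)
params = jnp.zeros(4)
key = jax.random.PRNGKey(0)
for _ in range(100):
    key, sub = jax.random.split(key)
    params = ea_step(sub, params, low, high, toy_loss, n_workers=4)
print(float(toy_loss(params)))  # never above the initial loss of 100.0
```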

Run Experiment Pipeline

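Now that we have defined the IMGEP internal modules, we can run the IMGEP experiment pipeline. As in the previous tutorial, we compare it with a random exploration strategy given the same budget of experiments. Here we use a total of N=5000 experiments in batches of 100: the IMGEP starts with 10 batches of random exploration (1000 experiments) followed by 40 batches of goal-directed exploration (4000 experiments), while random search runs all 50 batches randomly.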

import os

jax_platform_name = "cpu"
seed = 0

# Run IMGEP
n_random_batches = 10 
n_imgep_batches = 40
batch_size = 100
imgep_experiment_data_save_folder = "data/periodic_imgep_data"
if not os.path.exists(os.path.join(imgep_experiment_data_save_folder, "history.pickle")):
    run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                         imgep_experiment_data_save_folder,
                         random_intervention_generator, intervention_fn,
                         null_perturbation_generator, null_perturbation_fn,
                         system_rollout, null_rollout_statistics_encoder,
                         goal_generator, gc_intervention_selector, gc_intervention_optimizer,
                         goal_embedding_encoder, goal_achievement_loss,
                         out_sanity_check=False, save_modules=False, save_logs=False)

# Run Random Search
rs_experiment_data_save_folder = "data/periodic_rs_data"
if not os.path.exists(os.path.join(rs_experiment_data_save_folder, "history.pickle")):
    run_rs_experiment(jax_platform_name, seed, n_random_batches+n_imgep_batches, batch_size, 
                      rs_experiment_data_save_folder,
                      random_intervention_generator, intervention_fn,
                      null_perturbation_generator, null_perturbation_fn,
                      system_rollout, null_rollout_statistics_encoder,
                      out_sanity_check=False, save_modules=False, save_logs=False)

Analysis of the discoveries

imgep_experiment_history = DictTree.load(os.path.join(imgep_experiment_data_save_folder, "history.pickle"))
imgep_reached_goals_embeddings = imgep_experiment_history.reached_goal_embedding_library
print(imgep_reached_goals_embeddings.shape)

rs_experiment_history = DictTree.load(os.path.join(rs_experiment_data_save_folder, "history.pickle")) 
key, *subkeys = jrandom.split(key, num=len(imgep_reached_goals_embeddings)+ 1)
rs_reached_goals_embeddings, _ = vmap(goal_embedding_encoder)(jnp.array(subkeys), rs_experiment_history.system_output_library)
print(rs_reached_goals_embeddings.shape)
(5000, 2501)
(5000, 2501)

Number of discovered oscillators

rs_is_periodic_bool, rs_offset_vals, rs_ampl_vals, rs_freq_vals = is_periodic(rs_experiment_history.system_output_library.ys[:,0,:], jnp.r_[-system_rollout.n_steps//2:0], system_rollout.deltaT, 40)
rs_is_periodic_ids = jnp.where(rs_is_periodic_bool)[0]
print(f"Random Search has discovered {len(rs_is_periodic_ids)} oscillator circuits out of N={len(rs_is_periodic_bool)} trials")

imgep_is_periodic_bool, imgep_offset_vals, imgep_ampl_vals, imgep_freq_vals = is_periodic(imgep_experiment_history.system_output_library.ys[:,0,:], jnp.r_[-system_rollout.n_steps//2:0], system_rollout.deltaT, 40)
imgep_is_periodic_ids = jnp.where(imgep_is_periodic_bool)[0]
print(f"Curiosity Search has discovered {len(imgep_is_periodic_ids)} oscillator circuits out of N={len(imgep_is_periodic_bool)} trials")
Random Search has discovered 42 oscillator circuits out of N=5000 trials
Curiosity Search has discovered 1167 oscillator circuits out of N=5000 trials
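is_periodic is an AutoDiscJax utility whose internals we do not reproduce here. To give an intuition of what the characteristics (A, ω, b) mean, the sketch below uses simple hypothetical estimators (estimate_oscillation is not the library's method): the offset b as the signal mean, the amplitude A as half the peak-to-peak range, and the main frequency ω, here in Hz, as the largest non-DC peak of the Fourier spectrum. Unlike is_periodic, it does not decide whether the signal is periodic at all.

```python
import jax.numpy as jnp


def estimate_oscillation(y, deltaT):
    """Hypothetical estimators of (amplitude, main frequency, offset) for a signal y."""
    offset = y.mean()
    amplitude = (y.max() - y.min()) / 2
    spectrum = jnp.abs(jnp.fft.rfft(y - offset))
    freqs = jnp.fft.rfftfreq(y.size, d=deltaT)
    main_freq = freqs[jnp.argmax(spectrum[1:]) + 1]  # skip the DC bin
    return amplitude, main_freq, offset


deltaT = 0.01
ts = jnp.arange(5000) * deltaT                      # 50 s window
y = 0.6 + 0.25 * jnp.sin(2 * jnp.pi * 0.5 * ts)     # toy oscillator: A=0.25, 0.5 Hz, b=0.6
A, omega, b = estimate_oscillation(y, deltaT)
print(float(A), float(omega), float(b))
```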

Diversity of discovered oscillators

Here, the analytic behavioral characterization (BC) space is the space of (amplitude $A$, main frequency $\omega$, offset $b$) of the discovered oscillators, where $(A, \omega, b)$ are estimated by the AutoDiscJax is_periodic utility function. In the code below, non-periodic runs are mapped to $(0, 0, 0)$, so they all fall into a single bin. Diversity is measured with the QD-score, a binning-based metric where the BC space is discretized into a collection of bins and diversity is quantified as the number of bins filled over the course of exploration. We use a regular grid in which each dimension of the BC space is split into 20 equally sized bins. Unlike in tutorial 1, we do not use the threshold-coverage metric, as it is hard to compute in $n$-dimensional spaces with $n \ge 3$.

imgep_reached_goals_embeddings = jnp.stack([imgep_offset_vals.at[~imgep_is_periodic_bool].set(0.0), 
                                            imgep_ampl_vals.at[~imgep_is_periodic_bool].set(0.0), 
                                            imgep_freq_vals.at[~imgep_is_periodic_bool].set(0.0)], -1)
rs_reached_goals_embeddings = jnp.stack([rs_offset_vals.at[~rs_is_periodic_bool].set(0.0), 
                                         rs_ampl_vals.at[~rs_is_periodic_bool].set(0.0), 
                                         rs_freq_vals.at[~rs_is_periodic_bool].set(0.0)], -1)

analytic_bc_space_low = jnp.minimum(jnp.nanmin(imgep_reached_goals_embeddings, 0), jnp.nanmin(rs_reached_goals_embeddings, 0))
analytic_bc_space_high = jnp.maximum(jnp.nanmax(imgep_reached_goals_embeddings, 0), jnp.nanmax(rs_reached_goals_embeddings, 0))
analytic_bc_space_extent = jnp.stack([analytic_bc_space_low, analytic_bc_space_high]).transpose()

def calc_analytic_bc_coverage_histograms(reached_goals_embeddings, analytic_bc_space_extent, n_bins=20, every_n_steps=1):
    def f(carry, goal_embedding):
        Hf = carry
        cur_Hf, _ = jnp.histogramdd(goal_embedding[jnp.newaxis], bins=n_bins,range=analytic_bc_space_extent)
        Hf = Hf + cur_Hf.transpose()
        return Hf, Hf

    final_coverage_histogram, coverage_histograms = lax.scan(f, jnp.zeros((n_bins, n_bins, n_bins), dtype=jnp.int32), reached_goals_embeddings[::every_n_steps])
    return coverage_histograms


imgep_coverage_histograms = calc_analytic_bc_coverage_histograms(imgep_reached_goals_embeddings, analytic_bc_space_extent, n_bins=20)
rs_coverage_histograms = calc_analytic_bc_coverage_histograms(rs_reached_goals_embeddings, analytic_bc_space_extent, n_bins=20)
Figure 3: Characteristics of the discovered oscillators.
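The histograms computed above hold cumulative counts per bin; the coverage-based diversity described earlier is then presumably the number of non-empty bins after each experiment, i.e. (H > 0).sum() at each step. The following self-contained sketch (filled_bins_curve is a hypothetical name; it mirrors calc_analytic_bc_coverage_histograms but returns the count directly) also shows why mapping all non-periodic runs to (0, 0, 0) makes them occupy a single bin.

```python
import jax
import jax.numpy as jnp
from jax import lax


def filled_bins_curve(points, extent, n_bins=20):
    """Number of filled bins of a regular n_bins^3 grid after each successive point."""
    def f(H, p):
        cur_H, _ = jnp.histogramdd(p[jnp.newaxis], bins=n_bins, range=extent)
        H = H + cur_H.astype(jnp.int32)
        return H, (H > 0).sum()
    _, counts = lax.scan(f, jnp.zeros((n_bins,) * 3, dtype=jnp.int32), points)
    return counts


extent = jnp.array([[0.0, 1.0]] * 3)   # (low, high) for each of the 3 BC dimensions
points = jax.random.uniform(jax.random.PRNGKey(0), (500, 3))
counts = filled_bins_curve(points, extent)

# 500 identical points, like non-periodic runs all mapped to (0, 0, 0), fill a single bin
collapsed = filled_bins_curve(jnp.zeros((500, 3)), extent)
print(int(counts[-1]), int(collapsed[-1]))
```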

👉 We can see that, once again, curiosity search is much more efficient than random search at revealing a diversity of possible oscillator behaviors. Given the same experimental budget of 5000 model rollouts, random search found only 42 configurations leading to periodic patterns, whereas curiosity search found 1167. Projecting the discoveries into the space of (amplitude, frequency, offset), we can see that curiosity search efficiently reveals and covers the reachable space (a-c), reaching hard-to-discover behaviors at the borders of the space (d-i).

Part 2: Curiosity search as an alternative to purely optimization-driven search strategies?

In Tom W. Hiscock's paper, gradient-descent-based optimization is shown to help design (synthetic) gene circuits with desired functionalities, using as an example the optimization of the transcriptional gene circuit parameters to generate oscillations with a desired amplitude $A$ and main frequency $\omega$. In that paper, the loss function is defined as $C = \sum_t \left(y_i(t) - (A \cos(2\pi\omega t) + b)\right)^2$, where $y_i$ is the observed node (here $i = 0$), and the Adam optimizer is used with parameters $lr = 0.1$, $b_1 = 0.02$, $b_2 = 0.001$. We use the same parameters except for the learning rate, which we set to $lr = 10^{-3}$ because 0.1 proved too large here. In the code below, the loss $L = \sqrt{C}$ is the Euclidean norm of the residual rather than its square, which has the same minimizers.

Note that in the original paper the offset $b$ is not considered and is fixed to $b = 0$, which leads to biologically inadmissible targets with negative gene expression levels. Here, we only consider targets that respect the plausible gene expression levels $0 \le y \le 1$ of the gene circuit model: we define $A \in [0.1, 0.5]$, $b \in [A, 1-A]$ and $\omega \in [0, 1]$, so that the target $A\cos(2\pi\omega t) + b$ always stays within $[b - A, b + A] \subseteq [0, 1]$.
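The bound $[b - A, b + A] \subseteq [0, 1]$ can be checked numerically. This sketch samples targets the way the experiment code below does (A uniform in [0.1, 0.5], ω from a Beta(2, 8) distribution, b uniform in [A, 1 − A]); sample_target is a hypothetical helper name, and the check is only up to float rounding.

```python
import jax
import jax.numpy as jnp


def sample_target(key):
    """Sample (A, w, b): A ~ U[0.1, 0.5], w ~ Beta(2, 8), b ~ U[A, 1 - A]."""
    k_a, k_w, k_b = jax.random.split(key, 3)
    A = jax.random.uniform(k_a, minval=0.1, maxval=0.5)
    w = jax.random.beta(k_w, a=2.0, b=8.0)
    b = jax.random.uniform(k_b, minval=A, maxval=1 - A)
    return A, w, b


deltaT, n_steps = 0.01, 10_000
ts = jnp.arange(n_steps) * deltaT
A, w, b = jax.vmap(sample_target)(jax.random.split(jax.random.PRNGKey(0), 200))
targets = A[:, None] * jnp.cos(2 * jnp.pi * w[:, None] * ts) + b[:, None]
print(float(targets.min()), float(targets.max()))  # within [0, 1]
```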

Model and Loss definition

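from functools import partial
from jax import Array, jit, lax
import equinox as eqx
import jax.numpy as jnp

# Model Rollout: integrate the GRN for n_steps Euler steps from a given (y0, c)
class ModelRollout(eqx.Module):
    deltaT: float
    y0: Array
    c: Array
    grn_step: SimpleModelStep

    def __init__(self, deltaT, y0, c, grn_step):
        super().__init__()
        self.deltaT = deltaT
        self.y0 = jnp.maximum(y0, 0.)  # clamp negative initial concentrations to 0
        self.c = c
        self.grn_step = grn_step

    @partial(jit, static_argnames=("n_steps",))
    def __call__(self, n_steps):
        def f(carry, x):
            y, w, c, t = carry
            # output the current state, carry the next one
            return self.grn_step(y, w, c, t, self.deltaT), (y, w, t)
        (y, w, c, t), (ys, ws, ts) = lax.scan(f, (self.y0, jnp.array([]), self.c, 0.0), jnp.arange(n_steps))
        ys = jnp.moveaxis(ys, 0, -1)  # shape (n, n_steps)
        ws = jnp.moveaxis(ws, 0, -1)
        return ys, ws, ts


def loss_pattern(ys, A, b, w):
    # target oscillation A*cos(2*pi*w*t) + b on the time grid of the Part 1 rollouts
    target_ys = A*jnp.cos(2*jnp.pi*w*random_system_outputs.ts)+b
    # Euclidean (L2) norm of the residual, i.e. the square root of the sum of squares C
    loss = jnp.sqrt(jnp.square(ys-target_ys).sum())
    return loss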

Stochastic Gradient Descent

We use the Adam optimizer as in Hiscock et al. [1], and define the loss_fn and make_step functions for the optax pipeline.

import optax
from jax import jit, value_and_grad

# Optax optimizer, loss function and update function
optim = optax.adam(1e-3, b1=0.02, b2=0.001)  # same b1, b2 as in Hiscock et al.; lr lowered from 0.1 to 1e-3

@jit
def loss_fn(params, A, b, w): 
    """loss function"""
    y0, c = params
    model = ModelRollout(deltaT, y0, c, SimpleModelStep())
    ys, ws, ts = model(n_steps)
    loss = loss_pattern(ys[0], A, b, w)
    return loss

@jit
def make_step(params, A, b, w, opt_state):
    """update function"""
    loss, grads = value_and_grad(loss_fn)(params, A, b, w)
    updates, opt_state = optim.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return loss, params, opt_state
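With b1 = 0.02 and b2 = 0.001, Adam keeps almost no memory of past gradients, so, in our reading, each update is close to −lr · sign(g): every parameter moves by roughly lr per step whatever the gradient magnitude. The sketch below (adam_updates is a hypothetical helper implementing the standard bias-corrected Adam rule, which is what optax.adam computes with its default eps = 1e-8) illustrates this on a few made-up gradient vectors.

```python
import jax.numpy as jnp


def adam_updates(grads_seq, lr=1e-3, b1=0.02, b2=0.001, eps=1e-8):
    """Standard bias-corrected Adam rule; returns the parameter update of each step."""
    m = jnp.zeros_like(grads_seq[0])
    v = jnp.zeros_like(grads_seq[0])
    updates = []
    for t, g in enumerate(grads_seq, start=1):
        m = b1 * m + (1 - b1) * g          # first-moment estimate (almost just g here)
        v = b2 * v + (1 - b2) * g ** 2     # second-moment estimate (almost just g**2 here)
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        updates.append(-lr * m_hat / (jnp.sqrt(v_hat) + eps))
    return jnp.stack(updates)


# made-up gradients spanning several orders of magnitude
grads_seq = jnp.array([[0.5, -40.0, 1e-3],
                       [0.4, -35.0, 2e-3],
                       [0.6, -45.0, 5e-4]])
updates = adam_updates(grads_seq)
print(updates / 1e-3)  # magnitudes close to 1, signs opposite to the gradients
```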

Evolutionary Algorithm

We also compare with CMA-ES (Covariance Matrix Adaptation Evolution Strategy), an evolutionary strategy that self-adapts its step size (and search covariance) and is widely used for black-box optimization problems.

The CMA-ES code we use is taken from the evojax library, and we refer the reader to their original codebase available at: https://github.com/google/evojax/blob/main/evojax/algo/cma_jax.py.

For the fitness evaluation, we simply use the negative of the previously-defined loss.
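CMA-ES itself is too long to reproduce here, but the ask/tell loop and the "fitness = −loss" convention can be illustrated with a much simpler evolution strategy. The sketch below is not CMA-ES: it samples an isotropic Gaussian population, recombines the best half with log-decreasing weights (as CMA-ES does for its mean), and simply shrinks the step size geometrically instead of adapting it. toy_loss and simple_es are hypothetical names.

```python
import jax
import jax.numpy as jnp


def toy_loss(x):
    # hypothetical stand-in for loss_fn: squared distance to a fixed optimum
    return jnp.sum((x - 3.0) ** 2)


def simple_es(loss, mean, sigma=1.0, pop_size=12, n_gens=100, seed=0):
    """Simplified (mu/mu_w, lambda)-ES with a fixed step-size decay (no covariance adaptation)."""
    key = jax.random.PRNGKey(seed)
    n_elite = pop_size // 2
    w = jnp.log(n_elite + 0.5) - jnp.log(jnp.arange(1, n_elite + 1))  # decreasing weights
    w = w / w.sum()
    for _ in range(n_gens):
        key, sub = jax.random.split(key)
        pop = mean + sigma * jax.random.normal(sub, (pop_size, mean.size))  # "ask"
        fitness = -jax.vmap(loss)(pop)                                     # fitness = -loss
        elite = pop[jnp.argsort(-fitness)[:n_elite]]                       # best first
        mean = w @ elite                                                   # "tell": recombine
        sigma = sigma * 0.97                                               # CMA-ES would adapt this
    return mean


x0 = jnp.zeros(15)
best = simple_es(toy_loss, x0)
print(float(toy_loss(x0)), float(toy_loss(best)))
```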

Run optimization pipelines

We consider 3 setups:
1. Gradient descent with a budget of N=5000 optimization steps (the same number of model rollouts as for curiosity search and random search), starting from a random initialization $y_0 \in [0,1]^{n}, W \in [-30,30]^{n \times n}, B \in [-10,10]^{n}$.
2. CMA-ES with a budget of about N=5000 model rollouts (candidate evaluations), also starting from a random initialization $y_0 \in [0,1]^{n}, W \in [-30,30]^{n \times n}, B \in [-10,10]^{n}$.
3. Gradient descent with a small budget of N=100 optimization steps (local refinement), starting from the best discoveries made by the curiosity search and random search exploration strategies.
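The CMA-ES budget is expressed in generations rather than rollouts. Assuming evojax's CMA_ES_JAX falls back to the standard CMA-ES default population size λ = 4 + ⌊3 ln d⌋ when pop_size is not given (d being the number of optimized parameters), the arithmetic behind the CMA-ES loop below is:

```python
import math

n = 3                                                 # number of nodes
param_size = n + n**2 + n                             # y0 (n) + W (n*n) + B (n)
pop_size = 4 + math.floor(3 * math.log(param_size))   # standard CMA-ES default
n_generations = math.ceil(5000 / pop_size)
print(param_size, pop_size, n_generations, n_generations * pop_size)  # 15 12 417 5004
```

Under this assumption, CMA-ES evaluates slightly more than 5000 candidate rollouts, spread over 417 generations of 12 candidates.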

# Generate RANDOM Target
A = jrandom.uniform(subkey, minval=0.1, maxval=0.5)
key, subkey = jrandom.split(key)
w = jrandom.beta(subkey, a=2, b=8)
key, subkey = jrandom.split(key)
b = jrandom.uniform(subkey, minval=A, maxval=1-A)

# Optax pipeline from RANDOM Init
key, subkey = jrandom.split(key)
y0_sgd = jrandom.uniform(subkey, shape=(n, ), minval=y0_low, maxval=y0_high)
key, subkey = jrandom.split(key)
c_sgd = jrandom.uniform(subkey, shape=(n**2+n, ), minval=c_low, maxval=c_high)

model = ModelRollout(deltaT, y0_sgd, c_sgd, SimpleModelStep())
ys_random, _, _ = model(n_steps)

n_optim_steps = 5000
opt_state = optim.init((y0_sgd, c_sgd))
loss_sgd = []
n_sgd_oscillators = 0
for optim_step_idx in range(n_optim_steps):
    loss, (y0_sgd, c_sgd), opt_state = make_step((y0_sgd, c_sgd), A, b, w, opt_state)
    loss_sgd.append(loss)

    # check whether gradient descent passes through some oscillator behaviors
    model = ModelRollout(deltaT, y0_sgd, c_sgd, SimpleModelStep())
    ys_sgd, _, _ = model(n_steps)
    is_periodic_bool, _, _, _ = is_periodic(ys_sgd[0,:], jnp.r_[-system_rollout.n_steps//2:0], deltaT, 40)
    n_sgd_oscillators += int(is_periodic_bool)

print(f"Gradient-descent optimization has discovered {n_sgd_oscillators} oscillator circuits out of N={n_optim_steps} trials")


# Optax pipeline from closest init in IMGEP discoveries
loss_imgep = vmap(loss_pattern, in_axes=(0,None,None,None))(imgep_experiment_history.system_output_library.ys[:,0], A, b, w)
imgep_best_idx = imgep_is_periodic_ids[loss_imgep[imgep_is_periodic_ids].argmin()]
y0_imgep = jnp.array([imgep_experiment_history.intervention_params_library.y[node_idx][imgep_best_idx, 0] for node_idx in range(n)])
c_imgep = jnp.array([imgep_experiment_history.intervention_params_library.c[param_idx][imgep_best_idx, 0] for param_idx in range(n**2+n)])

model = ModelRollout(deltaT, y0_imgep, c_imgep, SimpleModelStep())
ys_imgep, _, _ = model(n_steps)

n_optim_steps = 100
opt_state = optim.init((y0_imgep, c_imgep))
loss_imgep_sgd = []
for optim_step_idx in range(n_optim_steps):
    loss, (y0_imgep, c_imgep), opt_state = make_step((y0_imgep, c_imgep), A, b, w, opt_state)
    loss_imgep_sgd.append(loss)

model = ModelRollout(deltaT, y0_imgep, c_imgep, SimpleModelStep())
ys_imgep_sgd, _, _ = model(n_steps)

## running minimum of the exploration loss (best loss found so far), for plotting
loss_imgep = lax.cummin(loss_imgep)

# Optax pipeline from closest init in RS discoveries
loss_rs = vmap(loss_pattern, in_axes=(0,None,None,None))(rs_experiment_history.system_output_library.ys[:,0], A, b, w)
rs_best_idx = rs_is_periodic_ids[loss_rs[rs_is_periodic_ids].argmin()]
y0_rs = jnp.array([rs_experiment_history.intervention_params_library.y[node_idx][rs_best_idx, 0] for node_idx in range(n)])
c_rs = jnp.array([rs_experiment_history.intervention_params_library.c[param_idx][rs_best_idx, 0] for param_idx in range(n**2+n)])

model = ModelRollout(deltaT, y0_rs, c_rs, SimpleModelStep())
ys_rs, _, _ = model(n_steps)

n_optim_steps = 100
opt_state = optim.init((y0_rs, c_rs))
loss_rs_sgd = []
for optim_step_idx in range(n_optim_steps):
    loss, (y0_rs, c_rs), opt_state = make_step((y0_rs, c_rs), A, b, w, opt_state)
    loss_rs_sgd.append(loss)

model = ModelRollout(deltaT, y0_rs, c_rs, SimpleModelStep())
ys_rs_sgd, _, _ = model(n_steps)

## running minimum of the random search loss (best loss found so far), for plotting
loss_rs = lax.cummin(loss_rs)


# CMA-ES pipeline from RANDOM Inits
key, subkey = jrandom.split(key)
y0_cma = jrandom.uniform(subkey, shape=(n, ), minval=y0_low, maxval=y0_high)
key, subkey = jrandom.split(key)
c_cma = jrandom.uniform(subkey, shape=(n**2+n, ), minval=c_low, maxval=c_high)

# CMA-ES starts from the random initial mean sampled above (y0_cma, c_cma)
solver = CMA_ES_JAX(param_size=n + n**2 + n, mean=jnp.concatenate([y0_cma, c_cma]), init_stdev=2, seed=0)
n_optim_steps = int(jnp.ceil(5000 / solver.pop_size).item())  # number of generations for ~5000 rollouts
loss_cma = []
n_cma_oscillators = 0
for optim_step_idx in range(n_optim_steps):
    params = solver.ask()            # one population of candidate parameter vectors
    y0_cma = params[:, :n]
    c_cma = params[:, n:]
    losses = vmap(loss_fn, (0, None, None, None))((y0_cma, c_cma), A, b, w)
    solver.tell(-losses)             # fitness = negative loss
    loss_cma += losses.tolist()

    # check whether the CMA-ES candidates pass through some oscillator behaviors
    for cur_y0_cma, cur_c_cma in zip(y0_cma, c_cma):
        model = ModelRollout(deltaT, cur_y0_cma, cur_c_cma, SimpleModelStep())
        ys_cma, _, _ = model(n_steps)
        is_periodic_bool, _, _, _ = is_periodic(ys_cma[0,:], jnp.r_[-system_rollout.n_steps//2:0], deltaT, 40)
        n_cma_oscillators += int(is_periodic_bool)

print(f"CMA-ES optimization has discovered {n_cma_oscillators} oscillator circuits out of N={n_optim_steps * solver.pop_size} trials ({n_optim_steps} generations)")

y0_cma = solver.best_params[:n]
c_cma = solver.best_params[n:]
model = ModelRollout(deltaT, y0_cma, c_cma, SimpleModelStep())
ys_cma, _, _ = model(n_steps)
Gradient-descent optimization has discovered 0 oscillator circuits out of N=5000 trials
CMA-ES optimization has discovered 1 oscillator circuits out of N=5004 trials (417 generations)

Analysis of the discoveries

Figure 4: Comparison of four alternative strategies for the design of oscillator circuits: curiosity search (blue), random search (pink), gradient descent (orange) and an evolutionary algorithm (green). (a-c) Given a budget of 5000 experiments, curiosity search finds 1167 oscillator circuits (circuits showing sustained oscillations), whereas random search finds only 42 and optimization-driven search fails to discover them (only one found by CMA-ES and none by gradient descent when starting from a random initialization). (a) 3D scatter plot of the 42 random search discoveries (pink) and the 1167 curiosity search discoveries (blue) in the (amplitude, main frequency, offset) analytic behavior space. (b) Box plots projecting the points of the 3D scatter plot onto each of the amplitude, main frequency and offset axes. (c) Diversity discovered throughout exploration, measured with a binning-based space coverage metric (20 bins per dimension). (d-g) Best discoveries (those minimizing $L$) made by the four exploration strategies. (h) Evolution of the optimization loss $L$ for the four algorithm variants. (i) Evolution of the local training loss when finetuning the best IMGEP (blue) and RS (pink) discoveries with gradient descent, with the finetuned results displayed in (j-k).

👉 We can see that gradient descent alone fails to discover an oscillator in this example, as it gets trapped in a strong local minimum (a constant signal with the same average as the target oscillator). This illustrates the challenge of choosing a proper loss and/or parameter initialization. While CMA-ES explores more at the beginning, it also gets stuck in a local minimum similar to that of gradient descent, showcasing how purely optimization-driven strategies can fail to properly explore the space of solutions.
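A small worked example helps explain why the flat signal is such a strong attractor for this loss. With a hypothetical target (A = 0.3, b = 0.5, ω = 0.2 Hz over 100 s, N = 10,000 samples), a constant signal at b has loss A·sqrt(N/2) ≈ 21.2, whereas an oscillation with the right amplitude and offset but a different frequency (0.25 Hz) scores about A·sqrt(N) = 30, and one with the right frequency but opposite phase about 2·A·sqrt(N/2) ≈ 42.4. Any oscillation that is not already close in both frequency and phase is therefore penalized more than no oscillation at all. The sketch reuses the form of loss_pattern but passes the time grid explicitly.

```python
import jax.numpy as jnp


def loss_pattern(ys, A, b, w, ts):
    # same form as the tutorial's loss_pattern, with the time grid passed explicitly
    target_ys = A * jnp.cos(2 * jnp.pi * w * ts) + b
    return jnp.sqrt(jnp.square(ys - target_ys).sum())


deltaT, n_steps = 0.01, 10_000
ts = jnp.arange(n_steps) * deltaT
A, b, w = 0.3, 0.5, 0.2  # hypothetical target

flat = jnp.full(n_steps, b)                                   # constant signal at the target mean
wrong_freq = A * jnp.cos(2 * jnp.pi * 0.25 * ts) + b          # right amplitude/offset, wrong frequency
anti_phase = A * jnp.cos(2 * jnp.pi * w * ts + jnp.pi) + b    # right frequency, opposite phase

losses = {name: float(loss_pattern(ys, A, b, w, ts))
          for name, ys in [("flat", flat), ("wrong_freq", wrong_freq), ("anti_phase", anti_phase)]}
print(losses)  # approximately {'flat': 21.2, 'wrong_freq': 30.0, 'anti_phase': 42.4}
```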

However, we found that optimization strategies can be useful for locally finetuning "close-enough" discoveries, such as those found by curiosity search and/or random search. Note that in this example, curiosity search reaches a better solution than random search, and does so more efficiently: (N=1057+100, L=13.97) versus (N=2663+100, L=22.92) for random search, where N counts the exploration rollouts needed to reach the best discovery plus the 100 finetuning steps.