
AI-driven Automated Discovery Tools for Synthetic Circuit Engineering

Authors and affiliations:
Mayalen Etcheverry: INRIA, Flowers team; Poietis
Clément Moulin-Frier: INRIA, Flowers team
Pierre-Yves Oudeyer: INRIA, Flowers team
Michael Levin: The Levin Lab, Tufts University

Published: September 2023

Introduction

TL;DR

This second tutorial accompanies our paper Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks. It focuses on the paper's last section, "Reuse of the framework as an alternative strategy to gene circuit engineering".


ModelStep function

When simulating synthetic gene regulatory networks, we typically assume a single family of ODEs. Here we use the transcriptional gene circuit model, with a simple model step defined as:

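$\frac{dy_i}{dt} = \phi\left(\sum_j W_{ij} y_j + B_i\right) - k_i y_i$

Here $\phi$ is the sigmoid function $\phi(x) = 1/(1+e^{-x})$, implemented as sigmoid below. We use $k_i=1$, $W_{ij}\in[-30,30]$ and $B_{i}\in[-10,10]$. Since $\phi$ takes values in $(0,1)$ and $k_i=1$, species concentrations that start in $[0,1]$ remain constrained to $y\in[0,1]$.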
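The following is a minimal pure-Python sketch of one forward-Euler step of this ODE with $k_i=1$, which makes the discretization concrete. It uses only the standard library, so it is not the tutorial's JAX code. It follows the same update y + deltaT * (sigmoid(W @ y + B) - y) that SimpleModelStep applies. The 2-node parameters are hypothetical values picked inside the stated ranges.

```python
import math

def sigmoid(x):
    """Logistic function phi(x) = 1 / (1 + exp(-x)), with values in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def euler_step(y, W, B, dt, k=1.0):
    """One forward-Euler step of dy_i/dt = phi(sum_j W_ij y_j + B_i) - k * y_i."""
    n = len(y)
    y_new = []
    for i in range(n):
        drive = sum(W[i][j] * y[j] for j in range(n)) + B[i]
        y_new.append(y[i] + dt * (sigmoid(drive) - k * y[i]))
    return y_new

# Hypothetical 2-node network with W_ij in [-30, 30] and B_i in [-10, 10]
W = [[0.0, -30.0],
     [30.0, 0.0]]
B = [5.0, -10.0]
y = [0.0, 1.0]  # initial concentrations in [0, 1]
dt = 0.01
for _ in range(1000):
    y = euler_step(y, W, B, dt)
```

Rewriting the update as $y_i^{new} = (1-\Delta t)\,y_i + \Delta t\,\phi(\cdot)$ shows that the bound also holds in discrete time. For $\Delta t \le 1$, each step is a convex combination of the current value and a number in $(0,1)$.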

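import jax.numpy as jnp
import equinox as eqx
from jax import jit

@jit
def sigmoid(x):
    # Logistic function phi(x) = 1 / (1 + exp(-x)), with values in (0, 1)
    return 1 / (1 + jnp.exp(-x))

class SimpleModelStep(eqx.Module):
    def __init__(self, **kwargs):
        super().__init__()

    @jit
    def __call__(self, y, w, c, t, deltaT):
        # One forward-Euler step of dy/dt = phi(W y + B) - y, i.e. k_i = 1
        n = len(y)
        # c packs the kinematic parameters: n*n entries of W (row-major), then n entries of B
        W = c[:n * n].reshape((n, n))
        B = c[n * n:(n + 1) * n]
        y_new = y + deltaT * (sigmoid(W @ y + B) - y)
        t_new = t + deltaT
        # w is not used by this model and is passed through unchanged
        w_new = w

        return y_new, w_new, c, t_new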
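SimpleModelStep reads all kinematic parameters from one flat vector c. The first $n^2$ entries are reshaped row by row into W, and the next $n$ entries form B. This means $W_{ij}$ is c[i*n + j] and $B_i$ is c[n*n + i]. The sketch below makes this layout explicit. The helper names pack_params and unpack_params are hypothetical, and plain lists stand in for JAX arrays.

```python
def pack_params(W, B):
    """Hypothetical helper: flatten W row by row, then append B (the layout of c)."""
    return [w for row in W for w in row] + list(B)

def unpack_params(c, n):
    """Hypothetical inverse of pack_params: W[i][j] = c[i*n + j], B[i] = c[n*n + i]."""
    W = [list(c[i * n:(i + 1) * n]) for i in range(n)]
    B = list(c[n * n:(n + 1) * n])
    return W, B

n = 3
W = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
B = [10, 11, 12]
c = pack_params(W, B)  # (n + 1) * n = 12 entries
```

The same layout explains the bounds used later in the tutorial. The first $n^2$ entries of c_low and c_high are $\mp 30$ (for W), and the last $n$ are $\mp 10$ (for B).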

Part 1: Curiosity-driven search for the discovery of diverse oscillator circuits

Experiment Pipeline and Modules

System Rollout module

Now that we have defined the new ModelStep function, AutoDiscJax lets us simulate system rollouts (and apply different kinds of interventions to them) in the same way as we did for biological networks in the first tutorial.

Let's instantiate the system rollout module.

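import jax.numpy as jnp
# grn is assumed to be the AutoDiscJax module already imported earlier in the notebook (as in the first tutorial)

n = 3  # number of nodes
deltaT = 0.01  # integration time step
n_secs = 100  # simulated duration
n_steps = int(n_secs / deltaT)  # number of Euler steps (10000)

# Kinematic parameters c = [W (n*n entries), B (n entries)]: placeholder values, later set by interventions
c = jnp.empty(((n + 1) * n, ))
c_low = jnp.array([-30.] * n**2 + [-10.] * n)
c_high = jnp.array([30.] * n**2 + [10.] * n)
grn_step = SimpleModelStep()

# Initial state y0 (placeholder) and the bounds used to sample it
y0 = jnp.empty(shape=(n,))
y0_low = 0.
y0_high = 1.

# SimpleModelStep does not use w and passes it through unchanged, so w0 is empty
w0 = jnp.array([])

system_rollout = grn.GRNRollout(n_steps=n_steps, y0=y0, w0=w0, c=c, t0=0.0, deltaT=deltaT, grn_step=grn_step)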
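We do not reproduce GRNRollout's internals here. We assume that it simply applies grn_step n_steps times, starting from (y0, w0, c, t0). Under that assumption, a rollout can be sketched in plain Python as follows. Here grn_step re-implements SimpleModelStep's update without JAX, and rollout is a hypothetical helper. With all-zero W and B, every node relaxes towards $\phi(0)=0.5$, which gives an easy sanity check.

```python
import math

def grn_step(y, w, c, t, deltaT):
    """Plain-Python analogue of SimpleModelStep.__call__ (same argument order)."""
    n = len(y)
    W = [c[i * n:(i + 1) * n] for i in range(n)]  # row-major W
    B = c[n * n:(n + 1) * n]
    y_new = []
    for i in range(n):
        drive = sum(W[i][j] * y[j] for j in range(n)) + B[i]
        y_new.append(y[i] + deltaT * (1.0 / (1.0 + math.exp(-drive)) - y[i]))
    return y_new, w, c, t + deltaT  # w and c are passed through unchanged

def rollout(step, y0, w0, c, t0, deltaT, n_steps):
    """Hypothetical helper: apply `step` n_steps times and record the trajectory."""
    y, w, t = list(y0), w0, t0
    ys, ts = [list(y)], [t]
    for _ in range(n_steps):
        y, w, c, t = step(y, w, c, t, deltaT)
        ys.append(y)
        ts.append(t)
    return ys, ts

n = 3
deltaT = 0.01
n_secs = 100
n_steps = int(n_secs / deltaT)
c = [0.0] * (n * n) + [0.0] * n  # hypothetical all-zero W and B
ys, ts = rollout(grn_step, [0.0, 0.5, 1.0], [], c, 0.0, deltaT, n_steps)
```

In this sketch the trajectory holds n_steps + 1 states, counting the initial one, and the final time is close to n_secs.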

Random Intervention Generator

Let's now use interventions to randomly set the GRN's initial state (y0) and kinematic parameters (c).

import jax.random as jrandom
import jax.tree_util as jtu
# Dict, DictTree, create_intervention_module and the PRNG key `key` are assumed
# to come from earlier in the notebook (as in the first tutorial)

# Create an intervention generator and an intervention_fn module to set the initial state and the kinematic parameters to random values
random_intervention_generator_config = Dict()
random_intervention_generator_config.intervention_type = "set_uniform"
# Single controlled interval [0, deltaT/2]: on the time grid it only contains the first step (t = 0)
random_intervention_generator_config.controlled_intervals = [[0, deltaT/2.0]]

# One leaf per controlled variable (each y[i] and each c[i]), with its sampling bounds
intervention_params_tree = DictTree()
intervention_params_low = DictTree()
intervention_params_high = DictTree()
for y_idx in range(len(y0)):
    intervention_params_tree.y[y_idx] = "placeholder"
    intervention_params_low.y[y_idx] = y0_low
    intervention_params_high.y[y_idx] = y0_high
for c_idx in range(len(c)):
    intervention_params_tree.c[c_idx] = "placeholder"
    intervention_params_low.c[c_idx] = c_low[c_idx]
    intervention_params_high.c[c_idx] = c_high[c_idx]

# Each leaf holds one float32 value per controlled interval
random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),), intervention_params_tree)
random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)
random_intervention_generator_config.low = intervention_params_low
random_intervention_generator_config.high = intervention_params_high

random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)
# Example: generate a random set of intervention parameters between low and high
key, subkey = jrandom.split(key)
intervention_params, log_data = random_intervention_generator(subkey)

# Run the system with the sampled intervention
key, subkey = jrandom.split(key)
random_system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn, intervention_params=intervention_params)
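The exact semantics of the "set_uniform" intervention type in AutoDiscJax are not shown in this tutorial. We offer a hedged sketch that makes two assumptions. First, the generator draws one value per controlled variable and per controlled interval, uniformly between low and high. Second, intervention_fn overwrites a variable whenever the simulation time falls inside a controlled interval.

Under those assumptions, the plain-Python sketch below shows what happens with controlled_intervals = [[0, deltaT/2]]. The names sample_set_uniform and active_interval are hypothetical. Only the first time step, t = 0, is controlled. The random values therefore act as the initial state y0, and as kinematic parameters c that SimpleModelStep then carries unchanged through the rest of the rollout.

```python
import random

n = 3
deltaT = 0.01
controlled_intervals = [[0.0, deltaT / 2.0]]

# Bounds mirroring the tutorial: y in [0, 1], W entries in [-30, 30], B entries in [-10, 10]
low = {"y": [0.0] * n, "c": [-30.0] * n**2 + [-10.0] * n}
high = {"y": [1.0] * n, "c": [30.0] * n**2 + [10.0] * n}

def sample_set_uniform(rng):
    """Hypothetical sampler: one uniform draw per variable and per controlled interval."""
    return {
        group: [[rng.uniform(lo, hi) for _ in controlled_intervals]
                for lo, hi in zip(low[group], high[group])]
        for group in low
    }

def active_interval(t):
    """Index of the controlled interval containing time t, or None if t is uncontrolled."""
    for k, (start, end) in enumerate(controlled_intervals):
        if start <= t <= end:
            return k
    return None

rng = random.Random(0)
params = sample_set_uniform(rng)
# On the simulation grid t = step * deltaT, list the steps that fall in a controlled interval
controlled_steps = [s for s in range(100) if active_interval(s * deltaT) is not None]
```

Each sampled leaf has length len(controlled_intervals), which matches the out_shape set in the configuration above.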