# Source code for sbmltoodejax.modulegeneration

import jax.numpy as jnp
from sbmltoodejax import jaxfuncs
import re
import sys

# Matches identifiers (variable and function names) inside libSBML math strings.
_IDENTIFIER_PATTERN = re.compile(r'\b[a-zA-Z_]\w*')

# libSBML math function / constant names -> JAX expressions used in the generated code.
_MATH_FUNCS = {
    'abs': 'jnp.abs',
    # NOTE(review): SBML max/min take their operands as separate arguments, whereas
    # jnp.max/jnp.min reduce one array (a second positional argument is `axis`).
    # Two-argument max/min probably need jnp.maximum/jnp.minimum -- confirm.
    'max': 'jnp.max',
    'min': 'jnp.min',
    'pow': 'jnp.power',
    'exp': 'jnp.exp',
    'floor': 'jnp.floor',
    'ceiling': 'jnp.ceil',
    'ln': 'jnp.log',
    'log': 'jnp.log10',
    'factorial': 'jaxfuncs.factorial',
    'sqrt': 'jnp.sqrt',
    'eq': 'jnp.equal',
    'neq': 'jnp.not_equal',
    'gt': 'jnp.greater',
    'lt': 'jnp.less',
    'geq': 'jnp.greater_equal',
    'leq': 'jnp.less_equal',
    # Function-call forms such as `and(a, b)` must map to callables: the Python
    # keywords `and`/`or` are a syntax error in that position, and `not(x)`
    # cannot be traced by jit.
    'and': 'jnp.logical_and',
    'or': 'jnp.logical_or',
    'xor': 'jnp.logical_xor',
    'not': 'jnp.logical_not',
    'sin': 'jnp.sin',
    'cos': 'jnp.cos',
    'tan': 'jnp.tan',
    'sec': 'jaxfuncs.sec',
    'csc': 'jaxfuncs.csc',
    'cot': 'jaxfuncs.cot',
    'sinh': 'jnp.sinh',
    'cosh': 'jnp.cosh',
    'tanh': 'jnp.tanh',
    'sech': 'jaxfuncs.sech',
    'csch': 'jaxfuncs.csch',
    'coth': 'jaxfuncs.coth',
    'arcsin': 'jnp.arcsin',
    'arccos': 'jnp.arccos',
    'arctan': 'jnp.arctan',
    'arcsinh': 'jnp.arcsinh',
    'arccosh': 'jnp.arccosh',
    'arctanh': 'jnp.arctanh',
    'true': 'True',
    'false': 'False',
    'notanumber': 'jnp.nan',
    'pi': 'jnp.pi',
    'infinity': 'jnp.inf',
    'exponentiale': 'jnp.e',
    'piecewise': 'jaxfuncs.piecewise',
}


def _species_amount(specie, compartments):
    """Return the initial amount of `specie`, converting a concentration into an amount.

    Raises:
        ValueError: if ``specie.valueType`` is neither "Amount" nor "Concentration".
    """
    if specie.valueType == "Amount":
        return specie.value
    if specie.valueType == "Concentration":
        return specie.value * compartments[specie.compartment].size
    raise ValueError("Specie value is not of type amount nor concentration")


def _dependency_order(entries, defined_vars):
    """Yield indices of `entries` so that each entry comes after the entries it depends on.

    Args:
        entries: sequence of ``(variable, math)`` pairs.
        defined_vars: names defined by some entry; only identifiers of `math`
            found in this collection count as dependencies.

    At every step the lowest-index entry whose dependencies have all been
    yielded is produced next, so entries listed first take precedence.
    Dependencies are released only when the generator is resumed, i.e. after
    the caller has handled the yielded entry.

    Raises:
        Exception: 'Algebraic Loop in AssignmentRules' when the remaining
            entries all wait on each other.
    """
    defined = set(defined_vars)
    pending = {
        index: [name for name in _IDENTIFIER_PATTERN.findall(math) if name in defined]
        for index, (_, math) in enumerate(entries)
    }
    while pending:
        ready = next((index for index, deps in pending.items() if not deps), None)
        if ready is None:
            raise Exception('Algebraic Loop in AssignmentRules')
        del pending[ready]
        yield ready
        resolved = entries[ready][0]
        for index in pending:
            pending[index] = [name for name in pending[index] if name != resolved]


def _model_step_source(step_name, rate_name, assignment_name, y_indexes, w_indexes, c_indexes,
                       atol, rtol, mxstep):
    """Return the source of the generated ModelStep module (one `odeint` step of size deltaT)."""
    return "".join([
        "class " + step_name + "(eqx.Module):\n",
        "\ty_indexes: dict = eqx.static_field()\n",
        "\tw_indexes: dict = eqx.static_field()\n",
        "\tc_indexes: dict = eqx.static_field()\n",
        f"\tratefunc: {rate_name}\n",
        "\tatol: float = eqx.static_field()\n",
        "\trtol: float = eqx.static_field()\n",
        "\tmxstep: int = eqx.static_field()\n",
        f"\tassignmentfunc: {assignment_name}\n\n",
        f"\tdef __init__(self, y_indexes={y_indexes}, w_indexes={w_indexes}, "
        f"c_indexes={c_indexes}, atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n",
        "\t\tself.y_indexes = y_indexes\n",
        "\t\tself.w_indexes = w_indexes\n",
        "\t\tself.c_indexes = c_indexes\n\n",
        f"\t\tself.ratefunc = {rate_name}()\n",
        "\t\tself.rtol = rtol\n",
        "\t\tself.atol = atol\n",
        "\t\tself.mxstep = mxstep\n",
        f"\t\tself.assignmentfunc = {assignment_name}()\n\n",
        "\t@jit\n",
        "\tdef __call__(self, y, w, c, t, deltaT):\n",
        "\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, "
        "atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\t\n",
        "\t\tt_new = t + deltaT\t\n",
        "\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\t\n",
        "\t\treturn y_new, w_new, c, t_new\t\n\n",
    ])


def _model_rollout_source(rollout_name, step_name, y0, w0, c, t0, deltaT, atol, rtol, mxstep):
    """Return the source of the generated ModelRollout module (a `lax.scan` over ModelStep)."""
    return "".join([
        "class " + rollout_name + "(eqx.Module):\n",
        "\tdeltaT: float = eqx.static_field()\n",
        f"\tmodelstepfunc: {step_name}\n\n",
        f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n",
        "\t\tself.deltaT = deltaT\n",
        f"\t\tself.modelstepfunc = {step_name}(atol=atol, rtol=rtol, mxstep=mxstep)\n\n",
        "\t@partial(jit, static_argnames=(\"n_steps\",))\n",
        f"\tdef __call__(self, n_steps, y0=jnp.array({y0}), w0=jnp.array({w0}), "
        f"c=jnp.array({c}), t0={t0}):\n\n",
        "\t\t@jit\n",
        "\t\tdef f(carry, x):\n",
        "\t\t\ty, w, c, t = carry\n",
        "\t\t\treturn self.modelstepfunc(y, w, c, t, self.deltaT), (y, w, t)\n",
        "\t\t(y, w, c, t), (ys, ws, ts) = lax.scan(f, (y0, w0, c, t0), jnp.arange(n_steps))\n",
        "\t\tys = jnp.moveaxis(ys, 0, -1)\n",
        "\t\tws = jnp.moveaxis(ws, 0, -1)\n",
        "\t\treturn ys, ws, ts\n\n",
    ])


def GenerateModel(modelData, outputFilePath, RateofSpeciesChangeName: str = 'RateofSpeciesChange',
                  AssignmentRuleName: str = 'AssignmentRule', ModelStepName: str = 'ModelStep',
                  ModelRolloutName: str = 'ModelRollout', vary_constant_reactants: bool = False,
                  vary_boundary_reactants: bool = False, deltaT: float = 0.1, atol: float = 1e-6,
                  rtol: float = 1e-12, mxstep: int = 5000000):
    """Generate a python (JAX / Equinox) file implementing an SBML model.

    This function takes model data created by :func:`~sbmltoodejax.parse.ParseSBMLFile` and generates a python
    file containing variables and modules that implement the SBML model.

    Note:
        This function is adapted from ``sbmltoodepy.modulegeneration.GenerateModel`` function, however the
        generated python file is written in JAX and follows different conventions for the generated variables and
        modules, as detailed in :ref:`structure-of-the-generated-python-file`

    The model quantities are split into three vectors in the generated file:

    * ``y`` -- integrated by the ODE solver: reactant species (as amounts; constant ones only when
      `vary_constant_reactants` is True) and targets of rate rules;
    * ``w`` -- targets of assignment rules, recomputed by the AssignmentRule module after every step;
    * ``c`` -- everything else: remaining species (as amounts), parameters, compartment sizes, and
      reaction-local parameters (named ``<reactionId>_<parameterId>``).

    Args:
        modelData (sbmltoodepy.dataclasses.ModelData): An object containing all the model components and values.
        outputFilePath (str): The desired file path of the resulting python file.
        RateofSpeciesChangeName (str, optional): The name of the RateofSpeciesChange module defined in the
            resulting python file. Defaults to 'RateofSpeciesChange'.
        AssignmentRuleName (str): The name of the AssignmentRule module defined in the resulting python file.
            Defaults to 'AssignmentRule'.
        ModelStepName (str): The name of the ModelStep module defined in the resulting python file.
            Defaults to 'ModelStep'.
        ModelRolloutName (str): The name of the ModelRollout module defined in the resulting python file.
            Defaults to 'ModelRollout'.
        vary_constant_reactants (bool): If True, reactant species flagged as constant are still placed in ``y``
            (and therefore integrated); otherwise they are stored in ``c``. Defaults to False.
        vary_boundary_reactants (bool): If True, the stoichiometric coefficients of boundary species are written
            into the stoichiometric matrix so that reactions change them; otherwise their matrix entries stay 0.
            Defaults to False.
        deltaT (float): Time step size (in seconds). Defaults to 0.1.
        atol (float): Absolute local error tolerance for the ``jax.experimental.ode.odeint`` solver.
            Defaults to 1e-6.
        rtol (float): Relative local error tolerance for the ``jax.experimental.ode.odeint`` solver.
            Defaults to 1e-12.
        mxstep (int): Maximum number of steps to take for each timepoint for the ``jax.experimental.ode.odeint``
            solver. Defaults to 5000000.

    Raises:
        NotImplementedError: for non-constant compartments, reactants that are not species, rate rules on
            compartments, and user-defined (custom) functions.
        ValueError: for species whose value is neither an amount nor a concentration, rate-rule targets that are
            not species/parameters/compartments, and reaction-local parameters clashing with ``y``/``w`` entries.
        Exception: for unknown identifiers in math expressions and for algebraic loops between rules.

    Side effects:
        Calls ``jnp.set_printoptions(threshold=sys.maxsize)``, which changes array printing globally. The output
        file is only written once generation has succeeded, so a failure never leaves a partial file.
    """
    # Process-wide setting kept from the original implementation; presumably ensures no
    # array repr is ever abbreviated with "...".
    jnp.set_printoptions(threshold=sys.maxsize)

    parameters = modelData.parameters
    compartments = modelData.compartments
    for compartment in compartments.values():
        if not compartment.isConstant:
            raise NotImplementedError("Varying compartment size is not handled")
    species = modelData.species
    reactions = modelData.reactions
    functions = modelData.functions
    assignmentRules = modelData.assignmentRules
    rateRules = modelData.rateRules
    initialAssignments = modelData.initialAssignments

    # TODO: Add in user defined functions

    # ================================================================================================
    # y: quantities integrated by the ODE solver.
    t0 = 0.0
    y0 = []
    y_indexes = {}
    for reaction in reactions.values():
        for _, reactant in reaction.reactants:
            if reactant in y_indexes:
                continue
            if reactant not in species:
                raise NotImplementedError("Reactant is not a specie")
            y_amount = _species_amount(species[reactant], compartments)
            # Constant reactants end up in c unless the caller asks to integrate them.
            if not species[reactant].isConstant or vary_constant_reactants:
                y0.append(y_amount)
                y_indexes[reactant] = len(y0) - 1

    for rule in rateRules.values():
        variable = rule.variable
        if variable in y_indexes:
            continue
        if variable in species:
            y_amount = _species_amount(species[variable], compartments)
        elif variable in parameters:
            y_amount = parameters[variable].value
        elif variable in compartments:
            raise NotImplementedError("Varying compartment size is not handled")
        else:
            raise ValueError(f"Rate rule variable {variable} is not a species, parameter nor compartment")
        y0.append(y_amount)
        y_indexes[variable] = len(y0) - 1

    # w: targets of assignment rules.
    w0 = []
    w_indexes = {}
    for rule in assignmentRules.values():
        variable = rule.variable
        if variable in species:
            w_amount = _species_amount(species[variable], compartments)
        elif variable in parameters:
            w_amount = parameters[variable].value
        elif variable in compartments:
            w_amount = compartments[variable].size
        else:
            # Placeholder only: such a target can never be in y_indexes, so its w0 slot
            # is always overwritten when the rules are evaluated below.
            w_amount = 0.0
        w0.append(w_amount)
        w_indexes[variable] = len(w0) - 1

    # c: every remaining species, parameter and compartment, plus reaction-local parameters.
    c = []
    c_indexes = {}

    def add_constant(name, value):
        c.append(value)
        c_indexes[name] = len(c) - 1

    for name, specie in species.items():
        if name not in y_indexes and name not in w_indexes:
            add_constant(name, _species_amount(specie, compartments))
    for name, parameter in parameters.items():
        if name not in y_indexes and name not in w_indexes:
            add_constant(name, parameter.value)
    for name, compartment in compartments.items():
        if name not in y_indexes and name not in w_indexes:
            add_constant(name, compartment.size)
    # Reaction-local parameters are namespaced by their reaction id.
    for reaction_name, reaction in reactions.items():
        for param in reaction.rxnParameters:
            param_name = reaction_name + "_" + param[0]
            if param_name in y_indexes or param_name in w_indexes:
                raise ValueError(f"Reaction parameter {param_name} clashes with a variable in y or w")
            if param_name not in c_indexes:
                add_constant(param_name, param[1])

    # ================================================================================================
    def ParseLHS(rawLHS):
        """Return the opening of the generated statement `w = w.at[i].set((...` for rule target `rawLHS`.

        For species not in substance units, the (concentration) rule value is scaled by the
        compartment size so that w stores an amount. The caller appends the RHS and '))'.
        """
        assert rawLHS in w_indexes
        returnLHS = f"w = w.at[{w_indexes[rawLHS]}].set("
        if rawLHS in species and not species[rawLHS].hasOnlySubstanceUnits:
            returnLHS += f"{compartments[species[rawLHS].compartment].size} * "
        return returnLHS + "("

    def ParseRHS(rawRateLaw, extended_param_names=(), reaction_name=None, yvar="y", wvar="w", cvar="c",
                 tvar="t"):
        """Translate a libSBML math string into a Python/JAX expression.

        Every identifier is replaced by an indexed access into `cvar`/`wvar`/`yvar` (checked in
        that order), a JAX function from ``_MATH_FUNCS``, or `tvar` for ``time``. Identifiers in
        `extended_param_names` are first prefixed with ``<reaction_name>_``. Species not in
        substance units are divided by their compartment size (y/w/c store amounts).

        Raises:
            NotImplementedError: for user-defined functions.
            Exception: for any identifier that cannot be resolved.
        """
        # Operators Python does not understand.
        rawRateLaw = rawRateLaw.replace("^", "**").replace("&&", "&").replace("||", "|")
        pieces = []
        cursor = 0
        for match in _IDENTIFIER_PATTERN.finditer(rawRateLaw):
            # Copy literals/operators between the previous identifier and this one verbatim.
            pieces.append(rawRateLaw[cursor:match.start()])
            cursor = match.end()
            name = match.group()
            if reaction_name is not None and name in extended_param_names:
                name = reaction_name + "_" + name
            as_concentration = name in species and not species[name].hasOnlySubstanceUnits
            if as_concentration:
                pieces.append('(')
            if name in c_indexes:
                pieces.append(f'{cvar}[{c_indexes[name]}]')
            elif name in w_indexes:
                pieces.append(f'{wvar}[{w_indexes[name]}]')
            elif name in y_indexes:
                pieces.append(f'{yvar}[{y_indexes[name]}]')
            elif name in _MATH_FUNCS:
                pieces.append(_MATH_FUNCS[name])
            elif name in functions:
                raise NotImplementedError("Custom functions are not handled")
            elif name == "time":
                pieces.append(tvar)
            else:
                raise Exception('New case: unkown Reaction variable: ' + name)
            if as_concentration:
                pieces.append(f'/{compartments[species[name].compartment].size})')
        pieces.append(rawRateLaw[cursor:])
        return ''.join(pieces)

    # ================================================================================================
    # Evaluate assignment rules and initial assignments at t0, in dependency order, to fill in
    # the initial values of y0 / w0 / c. Rules take precedence over initial assignments
    # whenever both are ready.
    rule_entries = [(rule.variable, rule.math) for rule in assignmentRules.values()]
    init_entries = rule_entries + [(assignment.variable, assignment.math)
                                   for assignment in initialAssignments.values()]
    # NOTE: eval runs expressions derived from the model file. ParseRHS rejects unknown
    # identifiers, but do not feed untrusted SBML to this generator.
    eval_locals = {"y0": y0, "w0": w0, "c": c, "t0": t0}
    for index in _dependency_order(init_entries, [variable for variable, _ in init_entries]):
        variable, math = init_entries[index]
        is_rule = index < len(rule_entries)
        value = eval(ParseRHS(math, yvar="y0", wvar="w0", cvar="c", tvar="t0"), globals(), eval_locals)
        if variable in species and not species[variable].hasOnlySubstanceUnits:
            value *= compartments[species[variable].compartment].size
        if isinstance(value, jnp.ndarray):
            value = value.item()
        if variable in y_indexes:
            y0[y_indexes[variable]] = value
        elif variable in w_indexes:
            w0[w_indexes[variable]] = value
        elif is_rule:
            raise ValueError("Rule variable is not in y nor w")
        elif variable in c_indexes:
            c[c_indexes[variable]] = value
        else:
            raise ValueError("Assignment variable is not in y, w nor c")

    # ================================================================================================
    # Generated source is accumulated here and written to disk only at the very end.
    out = [
        "import equinox as eqx\n",
        "from functools import partial\n",
        "from jax import jit, lax, vmap\n",
        "from jax.experimental.ode import odeint\n",
        "import jax.numpy as jnp\n\n",
        "from sbmltoodejax import jaxfuncs\n\n",
        f"t0 = {t0}\n\n",
        f"y0 = jnp.array({y0})\n",
        f"y_indexes = {y_indexes}\n\n",
        f"w0 = jnp.array({w0})\n",
        f"w_indexes = {w_indexes}\n\n",
        f"c = jnp.array({c}) \n",
        f"c_indexes = {c_indexes}\n\n",
    ]

    # ================================================================================================
    # Stoichiometric matrix: rows follow y_indexes, columns follow reaction order.
    stoichCoeffMat = jnp.zeros([len(y_indexes), max(len(reactions), 1)])
    for column, reaction in enumerate(reactions.values()):
        for coeff, reactant in reaction.reactants:
            if reactant in y_indexes and (not species[reactant].isBoundarySpecies or vary_boundary_reactants):
                stoichCoeffMat = stoichCoeffMat.at[y_indexes[reactant], column].add(coeff)

    rateArray = ['0.0'] * len(y_indexes)
    for rule in rateRules.values():
        rateArray[y_indexes[rule.variable]] = 'self.Rate' + rule.variable + '(y, w, c, t)'

    if reactions:
        velocity_terms = ', '.join('self.' + str(reaction_id) + '(y, w, c, t)' for reaction_id in reactions)
    else:
        velocity_terms = '0'

    out += [
        "class " + RateofSpeciesChangeName + "(eqx.Module):\n",
        f"\tstoichiometricMatrix = jnp.array({stoichCoeffMat.tolist()}, dtype=jnp.float32) \n\n",
        "\t@jit\n",
        "\tdef __call__(self, y, t, w, c):\n",
        '\t\trateRuleVector = jnp.array([' + ', '.join(rateArray) + '], dtype=jnp.float32)\n\n',
        '\t\treactionVelocities = self.calc_reaction_velocities(y, w, c, t)\n\n',
        '\t\trateOfSpeciesChange = self.stoichiometricMatrix @ reactionVelocities + rateRuleVector\n\n',
        '\t\treturn rateOfSpeciesChange\n\n',
        '\n\tdef calc_reaction_velocities(self, y, w, c, t):\n',
        '\t\treactionVelocities = jnp.array([' + velocity_terms + '], dtype=jnp.float32)\n\n',
        '\t\treturn reactionVelocities\n\n',
    ]
    for reaction_name, reaction in reactions.items():
        out.append(f'\n\tdef {reaction_name}(self, y, w, c, t):\n')
        rxnParamNames = [param[0] for param in reaction.rxnParameters]
        rateLaw = ParseRHS(reaction.rateLaw, extended_param_names=rxnParamNames, reaction_name=reaction_name)
        out.append('\t\treturn ' + rateLaw + '\n\n')
    for rateRule in rateRules.values():
        out.append("\tdef Rate" + rateRule.variable + "(self, y, w, c, t):\n")
        out.append('\t\treturn ' + ParseRHS(rateRule.math) + '\n\n')

    # ================================================================================================
    # AssignmentRule module: one `w = w.at[i].set(...)` per rule, in dependency order.
    out += [
        "class " + AssignmentRuleName + "(eqx.Module):\n",
        "\t@jit\n",
        "\tdef __call__(self, y, w, c, t):\n",
    ]
    for index in _dependency_order(rule_entries, [variable for variable, _ in rule_entries]):
        variable, math = rule_entries[index]
        out.append("\t\t" + ParseLHS(variable) + ParseRHS(math) + '))\n\n')
    out.append("\t\treturn w\n\n")

    # ================================================================================================
    out.append(_model_step_source(ModelStepName, RateofSpeciesChangeName, AssignmentRuleName,
                                  y_indexes, w_indexes, c_indexes, atol, rtol, mxstep))
    out.append(_model_rollout_source(ModelRolloutName, ModelStepName, y0, w0, c, t0,
                                     deltaT, atol, rtol, mxstep))

    # ================================================================================================
    # Written in one go: a failure above never truncates an existing file, and the handle is
    # always closed.
    with open(outputFilePath, "w") as outputFile:
        outputFile.write("".join(out))