jaxley.synapses.CurrentSynapse
- class CurrentSynapse(nonlinearity=<PjitFunction of <function sigmoid>>, name=None)
Bases: Synapse
A current-based synapse.
The current of this synapse depends only on the pre-synaptic voltage.
This synapse implements the following equation:
\[I = \overline{g} \cdot \sigma\!\left( \frac{V_{\text{pre}} - V_{\text{thr}}}{\Delta} \right)\]
where \(\sigma(\cdot)\) is a nonlinearity such as a ReLU, sigmoid, or tanh. It defaults to a sigmoid, but the user can pass a different one.
More informally: the synaptic current is a nonlinear function of the pre-synaptic voltage.
- The synaptic parameters are:
  - gS: the maximal conductance \(\overline{g}\) (uS).
  - v_th: the threshold \(V_{\text{thr}}\) at which the synapse becomes active (mV).
  - delta: the inverse of the slope of the activation, \(\Delta\) (mV).
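As a quick numerical check of the equation, the following pure-Python sketch evaluates it directly. It is not Jaxley's implementation: the helper names `sigmoid` and `synaptic_current` are hypothetical, and the values of gS, v_th, and delta are borrowed from the sigmoid usage example on this page.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def synaptic_current(v_pre, g_s, v_th, delta, nonlinearity=sigmoid):
    """Hypothetical helper: I = g_s * nonlinearity((v_pre - v_th) / delta).

    Only the pre-synaptic voltage appears; the post-synaptic voltage is unused.
    """
    return g_s * nonlinearity((v_pre - v_th) / delta)


# Parameter values taken from the sigmoid usage example.
g_s, v_th, delta = 0.0001, -40.0, 10.0

for v_pre in (-70.0, -40.0, -10.0):
    current = synaptic_current(v_pre, g_s, v_th, delta)
    print(f"V_pre = {v_pre:6.1f} mV -> I = {current:.3e}")
```

When \(V_{\text{pre}} = V_{\text{thr}}\), the sigmoid evaluates to 0.5, so the current is exactly half of \(\overline{g}\); far above threshold it approaches \(\overline{g}\), and far below it approaches zero.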
Example usage
Insert a synapse with a sigmoid nonlinearity (the default) and change its parameters.
```python
import jaxley as jx
from jaxley.connect import connect
from jaxley.synapses import CurrentSynapse

cell = jx.Cell()
net = jx.Network([cell for _ in range(2)])

# Connect neurons with the `CurrentSynapse`.
connect(net.cell(0), net.cell(1), CurrentSynapse())

# Set parameters.
net.set("CurrentSynapse_gS", 0.0001)  # Maximal conductance.
net.set("CurrentSynapse_v_th", -40.0)  # Threshold.
net.set("CurrentSynapse_delta", 10.0)  # 1 / slope of activation.
```
Insert a synapse with a ReLU nonlinearity.
```python
import jaxley as jx
from jaxley.connect import connect
from jaxley.synapses import CurrentSynapse
from jax.nn import relu

cell = jx.Cell()
net = jx.Network([cell for _ in range(2)])

# Connect neurons with the `CurrentSynapse`.
connect(net.cell(0), net.cell(1), CurrentSynapse(relu))
```
Insert a synapse with a custom nonlinearity.
```python
import jaxley as jx
from jaxley.connect import connect
from jaxley.synapses import CurrentSynapse

cell = jx.Cell()
net = jx.Network([cell for _ in range(2)])


def nonlinearity(x):
    return x ** 2


# Connect neurons with the `CurrentSynapse`.
connect(net.cell(0), net.cell(1), CurrentSynapse(nonlinearity))
```
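The choice of nonlinearity changes how the current behaves on either side of the threshold. The pure-Python sketch below compares the default sigmoid, a ReLU, and the squaring function from the custom example, for a pre-synaptic voltage 20 mV below and 20 mV above \(V_{\text{thr}}\). The helpers are hypothetical stand-ins written without JAX; inside Jaxley one would instead pass functions such as `jax.nn.relu`, as in the ReLU example.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def relu(x: float) -> float:
    return max(x, 0.0)


def square(x: float) -> float:
    # Same as the custom `nonlinearity` in the example above.
    return x ** 2


def synaptic_current(v_pre, g_s, v_th, delta, nonlinearity):
    """Hypothetical helper evaluating I = g_s * nonlinearity((v_pre - v_th) / delta)."""
    return g_s * nonlinearity((v_pre - v_th) / delta)


g_s, v_th, delta = 0.0001, -40.0, 10.0
results = {}
for name, fn in [("sigmoid", sigmoid), ("relu", relu), ("square", square)]:
    below = synaptic_current(-60.0, g_s, v_th, delta, fn)  # 20 mV below threshold
    above = synaptic_current(-20.0, g_s, v_th, delta, fn)  # 20 mV above threshold
    results[name] = (below, above)
    print(f"{name:8s} below: {below:.2e}  above: {above:.2e}")
```

With the ReLU, no current flows below threshold. With `x ** 2`, the current is symmetric around the threshold, so a strongly hyperpolarized pre-synaptic neuron drives as much current as an equally depolarized one, which may or may not be the intended behavior.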
- synapse_params = None
- synapse_states = None
- update_states(synapse_states, synapse_params, pre_voltage, post_voltage, pre_states, post_states, pre_params, post_params, delta_t)
Return updated synapse state and current.
- compute_current(synapse_states, synapse_params, pre_voltage, post_voltage, pre_states, post_states, pre_params, post_params, delta_t)
Return current through one synapse in nA.
Internally, we use jax.vmap to vectorize this function across many synapses.
- Parameters:
- Returns:
Current through the synapse in nA, shape ().
- Return type:
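The vectorization mentioned for `compute_current` follows the usual `jax.vmap` pattern: write the computation for a single synapse with scalar inputs (output shape `()`), then map it over arrays holding one entry per synapse. The sketch below illustrates only that pattern. It is not Jaxley's source: `single_synapse_current` is a hypothetical function with a reduced argument list, and storing the parameters as arrays keyed by the `CurrentSynapse_` names used with `net.set` is an assumption.

```python
import jax
import jax.numpy as jnp


def single_synapse_current(params, pre_voltage):
    """Hypothetical per-synapse current: scalar inputs, output of shape ()."""
    x = (pre_voltage - params["CurrentSynapse_v_th"]) / params["CurrentSynapse_delta"]
    return params["CurrentSynapse_gS"] * jax.nn.sigmoid(x)


# One entry per synapse for each parameter (three synapses here).
params = {
    "CurrentSynapse_gS": jnp.array([1e-4, 2e-4, 1e-4]),
    "CurrentSynapse_v_th": jnp.array([-40.0, -40.0, -50.0]),
    "CurrentSynapse_delta": jnp.array([10.0, 10.0, 5.0]),
}
pre_voltage = jnp.array([-40.0, -40.0, -70.0])

# in_axes=0 maps over the leading axis of every array in `params` and of `pre_voltage`.
batched_current = jax.vmap(single_synapse_current, in_axes=(0, 0))
currents = batched_current(params, pre_voltage)
print(currents.shape)  # (3,)
```

Writing the scalar version first keeps the per-synapse logic simple; `vmap` then handles the batch dimension without explicit loops.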