Spike counting simulation example

Bursting in excitable cells is the grouping of fast spikes into clusters, or bursts. It arises when the fast spiking dynamics are modulated by one or more slower variables. This example shows pseudo-plateau bursting in a simple model similar to a foundational model of pancreatic beta cells due to Teresa Chay and Joel Keizer. The model has two fast variables, voltage (\(v\)) and delayed-rectifier potassium channel activation (\(n\)), and one "slow" variable, calcium concentration (\(c\)). Here, \(c\) is not much slower than \(n\), which is important for the pseudo-plateau bursting mechanism.

When studying bursting, numerical simulations can be used to observe the transitions from spiking to increasingly long bursts via spike-adding bifurcations. Here we show how the number of spikes per burst varies as a function of two parameters, the voltage-dependent calcium channel conductance (\(g_{Ca}\)) and the plasma membrane calcium ATPase pump rate (\(k_{PMCA}\)).

The following script runs the two-parameter sweep using the clODE FeatureSimulator to generate a spike counting diagram, as well as some trajectories of interest using the TrajectorySimulator.

>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from typing import List
>>> import clode
>>> from clode import exp
>>> # Right-hand side of the three-variable bursting model; it is passed to the
>>> # clODE simulators below as `rhs_equation`.
>>> #   t      -- time (not used: the equations do not depend on t explicitly)
>>> #   x      -- state: x[0] = v (voltage), x[1] = n (delayed-rectifier K+
>>> #             activation), x[2] = c (calcium concentration)
>>> #   p      -- parameters: p[0] = gca, p[1] = gkca, p[2] = kpmca
>>> #   dx     -- output: time derivatives of v, n and c, written in place
>>> #   aux, w -- neither read nor written here; presumably auxiliary outputs
>>> #             and noise terms required by clODE's signature (TODO confirm)
>>> def get_rhs(t: float,
...             x: List[float],
...             p: List[float],
...             dx: List[float],
...             aux: List[float],   
...             w: List[float]) -> None:    
...     # Unpack the state variables.
...     v: float = x[0]
...     n: float = x[1]
...     c: float = x[2]
... 
...     # Unpack the parameters supplied through p.
...     gca: float = p[0]
...     gkca: float = p[1]
...     kpmca: float = p[2]
...     # gk is a fixed constant here, not an entry of p.
...     gk: float = 3500.0
... 
...     # Fixed model constants. vca and vk enter the currents as (v - vca) and
...     # (v - vk); cm divides the total current; presumably reversal potentials
...     # and membrane capacitance respectively (names only -- TODO confirm units).
...     vca: float = 25.0
...     vk: float = -75.0
...     cm: float = 5300.0
...     alpha: float = 4.5e-6
...     fcyt: float = 0.01
...     kd: float = 0.4
...     vm: float = -20.0
...     sm: float = 12.0
...     vn: float = -16.0
...     sn: float = 5.0
...     taun: float = 20.0
...     
...     # Sigmoidal steady-state functions of v (midpoints vm, vn; slopes sm, sn).
...     minf: float = 1.0/(1.0 + exp((vm - v)/sm))
...     ninf: float = 1.0/(1.0 + exp((vn - v)/sn))
...     # Second-order Hill function of c with half-saturation kd.
...     omega: float = c**2/(c**2 + kd**2)
... 
...     # Currents: conductance * gating term * (v - reference voltage).
...     ica: float = gca*minf*(v - vca)
...     ik: float = gk*n*(v - vk)
...     ikca: float = gkca*omega*(v - vk)
... 
...     # dv/dt: negative sum of the currents, scaled by cm.
...     dx[0] = -(ica + ik + ikca)/cm
...     # dn/dt: n relaxes towards ninf with time constant taun.
...     dx[1] = (ninf - n)/taun
...     # dc/dt: influx proportional to -ica minus removal kpmca*c, scaled by fcyt.
...     dx[2] = fcyt*(-alpha*ica - kpmca*c)
... 
>>> # Initial conditions and default parameter values. The keys use the same
>>> # names, in the same order, as the entries get_rhs unpacks from x and p.
>>> variables = {"v": -50.0, "n": 0.01, "c": 0.12}
>>> parameters = {"gca": 1200.0, "gkca": 750.0, "kpmca": 0.1}
>>> # Integration window. Time is presumably in ms: the trajectory plot further
>>> # down divides t by 1000 and labels the axis in seconds.
>>> t_span=(0.0, 30000.0)
>>> # Feature simulator used for the two-parameter spike-counting sweep.
>>> integrator = clode.FeatureSimulator(
...     rhs_equation=get_rhs,
...     variables=variables,
...     parameters=parameters,
...     single_precision=True,
...     t_span=t_span,
...     stepper=clode.Stepper.dormand_prince,
...     dt=0.001,
...     dtmax=1.0,
...     abstol=1e-6,
...     reltol=1e-5,
...     # Events are detected on, and features measured from, the voltage v.
...     event_var="v",
...     feature_var="v",
...     # NOTE(review): the exact meaning of the observer_* settings below is
...     # defined by clODE's threshold_2 observer -- confirm against its docs.
...     observer=clode.Observer.threshold_2,
...     observer_x_up_thresh=0.5,
...     observer_x_down_thresh=0.05,
...     observer_min_x_amp=1.0,
...     observer_min_imi=0.0,
...     observer_max_event_count=50,
... )
>>> nx = 64
>>> ny = 64
>>> nPts = nx * ny
>>> gca = np.linspace(550.0, 1050.0, nx)
>>> kpmca = np.linspace(0.095, 0.155, ny)
>>> px, py = np.meshgrid(gca, kpmca)
>>> ensemble_parameters = {"gca" : px.flatten(), "kpmca" : py.flatten()} #gkca will have default value
>>> ensemble_parameters_names = list(ensemble_parameters.keys())
>>> integrator.set_ensemble(parameters=ensemble_parameters)
>>> integrator.transient()
>>> integrator.features()
<clode.features.ObserverOutput object at 0x114f74980>
>>> features = integrator.get_observer_results()
>>> feature = features.get_var_max("peaks")
>>> feature = np.reshape(feature, (nx, ny))
>>> plt.pcolormesh(px, py, feature, shading='nearest', vmax=12)
<matplotlib.collections.QuadMesh object at 0x1152efa70>
>>> plt.title("peaks")
Text(0.5, 1.0, 'peaks')
>>> plt.colorbar()
<matplotlib.colorbar.Colorbar object at 0x114d692b0>
>>> plt.xlabel(ensemble_parameters_names[0])
Text(0.5, 0, 'gca')
>>> plt.ylabel(ensemble_parameters_names[1])
Text(0, 0.5, 'kpmca')
>>> plt.axis("tight")
(546.031746031746, 1053.968253968254, 0.09452380952380952, 0.1554761904761905)
>>> points = np.array([[950, 0.145], [700, 0.105], [750, 0.125], [800, 0.142]])
>>> plt.plot(points[:, 0], points[:, 1], 'o', color='black')
[<matplotlib.lines.Line2D object at 0x1104674a0>]
>>> for i, txt in enumerate(range(4)):
...     plt.annotate(txt, (points[i, 0] - 10, points[i, 1] - 0.003))
... 
Text(940.0, 0.142, '0')
Text(690.0, 0.102, '1')
Text(740.0, 0.122, '2')
Text(790.0, 0.13899999999999998, '3')
>>> plt.show()
>>> # Size the trajectory storage from the feature run: take the largest "step"
>>> # count reported over all grid points.
>>> steps_taken = features.get_var_count("step")
>>> max_steps = int(np.max(steps_taken))
>>> # Trajectory simulator built with the same model, stepper and tolerances as
>>> # the feature simulator.
>>> integrator_traj = clode.TrajectorySimulator(
...     rhs_equation=get_rhs,
...     variables = variables,
...     parameters = parameters,
...     single_precision = True,
...     t_span=t_span,
...     stepper = clode.Stepper.dormand_prince,
...     dt = 0.001,
...     dtmax = 1.0,
...     abstol = 1e-6,
...     reltol = 1e-5,
...     max_steps = max_steps,
...     max_store = max_steps,
... )
>>> # Simulate only the marked (gca, kpmca) points; gkca keeps its default value.
>>> traj_parameters = {"gca":points[:, 0], "kpmca": points[:, 1]}
>>> integrator_traj.set_ensemble(parameters = traj_parameters)
>>> # NOTE(review): transient() presumably advances the state over t_span
>>> # without storing it -- confirm. The stored run then uses a shorter window.
>>> integrator_traj.transient()
>>> integrator_traj.set_tspan((0.0, 10000.0))
>>> integrator_traj.trajectory()
[TrajectoryResult(t=[0.00000000e+00 1.00000005e-03 6.00000052e-03 ... 9.99878125e+03
 9.99978125e+03 1.00007812e+04], x=[[-2.75978088e+01  7.66079351e-02  2.38926128e-01]
 [-2.75987015e+01  7.66085833e-02  2.38926560e-01]
 [-2.76031647e+01  7.66117945e-02  2.38928720e-01]
 ...
 [-5.11058693e+01  9.34081734e-04  2.30074629e-01]
 [-5.10095711e+01  9.32450930e-04  2.29968384e-01]
 [-5.09101524e+01  9.31767107e-04  2.29863733e-01]]), TrajectoryResult(t=[0.00000000e+00 1.00000005e-03 6.00000052e-03 ... 9.99878125e+03
 9.99978125e+03 1.00007812e+04], x=[[-2.69551754e+01  2.88627651e-02  1.73494622e-01]
 [-2.69547043e+01  2.88663507e-02  1.73495024e-01]
 [-2.69523487e+01  2.88842786e-02  1.73497051e-01]
 ...
 [-6.38102036e+01  6.54436881e-05  1.82811186e-01]
 [-6.37872810e+01  6.56906122e-05  1.82690158e-01]
 [-6.37641449e+01  6.59413708e-05  1.82569385e-01]]), TrajectoryResult(t=[0.00000000e+00 1.00000005e-03 6.00000052e-03 ... 9.99878125e+03
 9.99978125e+03 1.00007812e+04], x=[[-6.16074638e+01  1.66308440e-04  2.15597406e-01]
 [-6.16075211e+01  1.66305588e-04  2.15597227e-01]
 [-6.16078072e+01  1.66291327e-04  2.15596318e-01]
 ...
 [-3.33382721e+01  5.07195108e-02  2.00778529e-01]
 [-3.38711548e+01  4.96464521e-02  2.01008961e-01]
 [-3.43922539e+01  4.84891981e-02  2.01227486e-01]]), TrajectoryResult(t=[0.00000000e+00 1.00000005e-03 6.00000052e-03 ... 9.99878125e+03
 9.99978125e+03 1.00007812e+04], x=[[-3.19367046e+01  3.45831551e-02  2.36781478e-01]
 [-3.19369488e+01  3.45834084e-02  2.36781701e-01]
 [-3.19381714e+01  3.45846713e-02  2.36782789e-01]
 ...
 [-4.15628471e+01  5.96884731e-03  2.20044225e-01]
 [-4.13606377e+01  5.97556680e-03  2.20074490e-01]
 [-4.11497803e+01  5.99444145e-03  2.20108718e-01]])]
>>> trajectories = integrator_traj.get_trajectory()
>>> fig, ax = plt.subplots(4, 1, sharex=True, sharey=True)
>>> for i, trajectory in enumerate(trajectories):
...     ax[i].plot(trajectory.t / 1000.0, trajectory.x[:, 0])
... 
[<matplotlib.lines.Line2D object at 0x1179c9a30>]
[<matplotlib.lines.Line2D object at 0x1179c9b80>]
[<matplotlib.lines.Line2D object at 0x1179c9e80>]
[<matplotlib.lines.Line2D object at 0x1179ca150>]
>>> ax[1].set_ylabel("v")
Text(0, 0.5, 'v')
>>> ax[-1].set_xlabel('time (s)')
Text(0.5, 0, 'time (s)')
>>> plt.show()

Output

The spike counting diagram shows silent, spiking, and bursting regions. Chaotic dynamics occur near some of the spike-adding bifurcation boundaries (yellow indicates >=12 spikes per event, as detected with the threshold observer method). The trajectories associated with the numbered points are shown in the following figure.