I need some help with a personal project, and I would be really grateful if you provided me some guidance on it ;).
In short, I am trying to simulate the dynamics of a neuron with Julia, and these involves solving a system of 5 differential equations with Julia‘s DifferentialEquations library, and then plotting the results. Nevertheless, when I display the results I see that sometimes the currents go below 0, something strange because I am preventing the currents from going below 0 with the clamp()
function.
The main goal is to provide a Julia version of Brian2 for simulating Neuronal Dynamics, hence I am expected to do the same that this simulator does, but with Julia.
Here I leave you what I have done so far, as well as the reference python code of equations used by Brian2 and that I am trying to replicate in Julia.
This is the main file:
using BenchmarkTools
function Soma!(du,u,p,t)
I0, kappa, Ut = p[1]
Csoma_mem, Isoma_dpi_tau, alpha_soma, Isoma_pfb_gain, Isoma_pfb_th, Isoma_pfb_norm, Isoma_const = p[2]
Campa, Iampa_tau, alpha_ampa = p[3]
Cnmda, Inmda_tau, alpha_nmda, Inmda_thr = p[4]
Cgaba_a, Igaba_a_tau, alpha_gaba_a = p[5]
Cgaba_b, Igaba_b_tau, alpha_gaba_b = p[5]
Isoma_dpi_tau_shunt = Isoma_dpi_tau*(u[1] > I0) + I0*(u[1] <= I0)
Isoma_dpi_g_shunt = alpha_soma*Isoma_dpi_tau_shunt*(u[1] > I0) + I0*(u[1] <= I0)
Isoma_pfb = Isoma_pfb_gain/(1+exp(-(u[1] - Isoma_pfb_th)/Isoma_pfb_norm))
Isoma_pfb_shunt = Isoma_pfb*(u[1] > I0) + I0*(u[1] <= I0)
Igaba_a_shunt = u[5]*(u[1]>I0) + I0*(u[1]<=I0)
Isoma_sum = Isoma_dpi_tau_shunt + Igaba_a_shunt - Isoma_pfb_shunt - I0*(u[1] <=I0)
tau_soma = (Csoma_mem * Ut) / (kappa * Isoma_dpi_tau_shunt)
Iampa_g = alpha_ampa * Iampa_tau # Synapse gain term expressed in terms of its tau current
Iampa_tau_shunt = Iampa_tau*(u[2]>I0) + I0*(u[2]<=I0) # Shunt tau current if Iampa goes to I0
Iampa_g_shunt = Iampa_g*(u[2]>I0) + I0*(u[2]<=I0) # Shunt g current if Iampa goes to I0
tau_ampa = (Campa * Ut) / (kappa * Iampa_tau_shunt) # Synaptic time-constant
Igaba_a_tau_shunt = Igaba_a_tau*(u[4] > I0) + I0*(u[4] <= I0)
Igaba_a_g_shunt = alpha_gaba_a*Igaba_a_tau*(u[4] > I0) + I0*(u[4] <= I0)
tau_gaba_a = (Cgaba_a * Ut) / (kappa * Igaba_a_tau_shunt)
Igaba_b_tau_shunt = Igaba_b_tau*(u[5] > I0) + I0*(u[5] <= I0)
Igaba_b_g_shunt = alpha_gaba_b*Igaba_b_tau*(u[5] > I0) + I0*(u[5] <= I0)
tau_gaba_b = (Cgaba_b * Ut) / (kappa * Igaba_b_tau_shunt)
Isoma_mem_clip = clamp(u[1], I0, 1) #Works
Inmda_g = alpha_nmda * Inmda_tau
Inmda_g_shunt = Inmda_g*Int((u[3]>I0)) + I0*Int((u[3]<=I0))
Inmda_tau_shunt = Inmda_tau*Int((u[3]>I0)) + I0*Int((u[3]<=I0))
tau_nmda = Cnmda * Ut /(kappa * Inmda_tau_shunt)
Inmda_dp = u[3]/(1 + Inmda_thr / Isoma_mem_clip)
Iin_clip = clamp(Inmda_dp + u[2] - u[5] + Isoma_const, I0, 1) # Works
du[1] = (alpha_soma*(Iin_clip - Isoma_sum) - (Isoma_sum - I0*Int((u[1] <= I0)))*(Isoma_mem_clip/Isoma_dpi_tau_shunt))/(tau_soma * (1 + (Isoma_dpi_g_shunt/Isoma_mem_clip)))
du[2] = (- u[2] - Iampa_g_shunt + 2*I0*(u[2] <= I0))/(tau_ampa * ((Iampa_g_shunt / u[2]) + 1)) #AMPA
du[3] = (- u[3] - Inmda_g_shunt + 2*I0*(u[3] <= I0))/(tau_nmda * ((Inmda_g_shunt / u[3]) + 1)) #NMDA
du[4] = (- u[4] - Igaba_a_g_shunt + 2*I0*(u[4] <= I0))/(tau_gaba_a * ((Igaba_a_g_shunt / u[4]) + 1))
du[5] = (- u[5] - Igaba_b_g_shunt + 2*I0*(u[5] <= I0))/(tau_gaba_b * ((Igaba_b_g_shunt / u[5]) + 1))
end
kappa = (kappa_n + kappa_p)/2
p1 = [I0, kappa, Ut]
p2 = [Csoma_mem, Isoma_dpi_tau, alpha_soma, Isoma_pfb_gain, Isoma_pfb_th, Isoma_pfb_norm, Isoma_const]
p3 = [Campa, Iampa_tau, alpha_ampa]
p4 = [Cnmda, Inmda_tau, alpha_nmda, Inmda_thr]
p5 = [Cgaba_a, Igaba_a_tau, alpha_gaba_a]
p6 = [Cgaba_b, Igaba_b_tau, alpha_gaba_b]
p = [p1,p2,p3,p4, p5, p6]
inp_duration = 5 #Input duration must be ≥ pulse_stop ???
pulse_start = 0
pulse_stop = 5
rate = 80
input_type = "regular"
inp = SpikeGen.input_gen(input_type,inp_duration, pulse_start, pulse_stop, rate)
indices = inp .== 1.0
tvalues = 0:pulse_stop/4999:pulse_stop
time_spikes = tvalues[indices]
function conditionSOMA(u, t, integrator)
u[1] <= Isoma_th
end
function affectSOMA!(integrator)
println("Affecting soma at time $(integrator.t)")
integrator.u[1] = Isoma_reset
end
cbSOMA = ContinuousCallback(conditionSOMA, affectSOMA!)
weight_AMPA = 2
conditionAMPA(u,t,integrator) = t ∈ time_spikes
affectAMPA!(integrator) = integrator.u[2] += weight_AMPA * alpha_ampa * Iampa_w0
cbAMPA = DiscreteCallback(conditionAMPA, affectAMPA!)
weight_NMDA = 0.5
#conditionNMDA(u,t,integrator) = t ∈ [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0, 1.5,2.0]
conditionNMDA(u,t,integrator) = t ∈ time_spikes
affectNMDA!(integrator) = integrator.u[3] += weight_NMDA * alpha_nmda * Inmda_w0
cbNMDA = DiscreteCallback(conditionNMDA, affectNMDA!)
weight_gaba_a = 1
#conditionGABA_A(u,t,integrator) = t ∈ [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0, 1.5,2.0]
conditionGABA_A(u,t,integrator) = t ∈ time_spikes
affectGABA_A!(integrator) = integrator.u[4] += weight_gaba_a * alpha_gaba_a * Igaba_a_w0
cbGABA_A = DiscreteCallback(conditionGABA_A, affectGABA_A!)
weight_gaba_b = 1
conditionGABA_B(u,t,integrator) = t ∈ time_spikes
affectGABA_B!(integrator) = integrator.u[5] += Igaba_b_w0*weight_gaba_b*alpha_gaba_b
cbGABA_B = DiscreteCallback(conditionGABA_B, affectGABA_B!)
callbacks = CallbackSet(cbSOMA, cbAMPA, cbNMDA, cbGABA_A, cbGABA_B)
tspan = (0.0, 5.0)
ini_soma = Isoma_reset
ini_soma = 0.1I0
ini_AMPA = 1.1I0
ini_NMDA = 1.1I0
ini_GABA_A = 1.1I0
ini_GABA_B = 1.1I0
u0 = [ini_soma, ini_AMPA, ini_NMDA, ini_GABA_A, ini_GABA_B]
prob = ODEProblem(Soma!, u0, tspan, p)
sol = solve(prob, Rodas5(), callback = callbacks, saveat = 0.01)
strings = ["Membrane current", "AMPA", "NMDA", "GABA_A", "GABA_B"]
plot(sol, xlabel = "Time [ms]")
plot1 = plot(sol, idxs = (1), title = strings[1], xlabel = "Time [s]", ylabel = "Isoma [A]")
plot2 = plot(sol, idxs = (2), title = strings[2], xlabel = "Time [s]", ylabel = "AMPA [A]")
plot3 = plot(sol, idxs = (3), title = strings[3], xlabel = "Time [s]", ylabel = "NMDA [A]")
plot4 = plot(sol, idxs = (4), title = strings[4], xlabel = "Time [s]", ylabel = "GABA A [A]")
plot5 = plot(sol, idxs = (5), title = strings[5], xlabel = "Time [s]", ylabel = "GABA B [A]")
P1 = plot(plot1)
P2 = plot(plot2,plot3, layout = (2,1))
P3 = plot(plot4,plot5, layout = (2,1))
index = 2 # 1 - Soma, 2 - AMPA & NMDA, 3 - GABAs
if index == 1
plot(P1)
elseif index == 2
plot(P2)
elseif index == 3
plot(P3)
else
nothing;
end
This is Setup.jl
, in which I have the values of the parameters:
using Pkg
using DifferentialEquations
using Plots
##########
#Constants
##########
kappa_n = 0.75
kappa_p = 0.66
kappa = (kappa_n + kappa_p)/2
Ut = 25e-3
I0 = 1e-12
################
#Scaling factors
################
alpha_soma = 4
alpha_ahp = 4
alpha_nmda = 4
alpha_ampa = 4
alpha_gaba_a = 4
alpha_gaba_b = 4
##################
#Neuron parameters
###################
#Soma
Csoma_mem = 2e-12
Isoma_mem = 1.1I0
Isoma_dpi_tau = 5I0
Isoma_th = 2000*I0
Isoma_reset = 1.2*I0
Isoma_const = I0
soma_refP = 5e-3
#Adaptation
Csoma_ahp = 4e-12
Isoma_ahp_tau = 2I0
Isoma_ahp_w = 1I0
#Positive feedback
Isoma_pfb_gain = 100I0
Isoma_pfb_th = 1000I0
Isoma_pfb_norm = 20I0
##################
#Synapse parameters
###################
#Slow_exc, NMDA
Cnmda = 2e-12
Inmda_tau = 2I0
Inmda_w0 = 10I0
Inmda = I0
Inmda_thr = I0
Inmda_g = alpha_nmda * Inmda_tau
#Fast_exc, AMPA
Campa = 2e-12;
Iampa_tau = 20I0
Iampa_w0 = 100I0
Iampa_g = alpha_ampa * Iampa_tau
#INH, SLOW_INH, GABA_B, subtractive
Cgaba_b = 2e-12;
Igaba_b_tau = 5I0
Igaba_b_w0 = 100I0
Igaba_b_g = alpha_gaba_b * Igaba_b_tau
#FAST_INH, GABA_A, shunting, a mixture of subatractive and divisive
#Positive feedback
Cgaba_a = 2e-12
Igaba_a_tau = 5I0
Igaba_a_w0 = 100I0
Igaba_a_g = alpha_gaba_a * Igaba_a_tau
#Voltage
taum = 20e-3
tauw = 30e-3
a = -0.5e-9
urest = -55e-3
urh = -55e-3
At = 2e-3
R = 500e6
This is SpikeGent.jl
, which generates a spike pattern to affect the synapses in each spike time:
module SpikeGen
export input_gen
function input_gen(input_type, inp_duration, pulse_start, pulse_stop, rate)
if input_type == "regular"
inp = reg_gen(inp_duration, pulse_start, pulse_stop, rate)
elseif input_type == "poisson"
inp = poisson_gen(inp_duration, pulse_start, pulse_stop, rate)
elseif input_type == "cosine"
inp = cos_gen(inp_duration, pulse_start, pulse_stop, rate)
else
error("Invalid input type")
end
return inp
end
function reg_gen(inp_duration, pulse_start, pulse_stop,rate)
# Regular input generator
inp = zeros(inp_duration*1000)
dt = Int(round(1000/rate))
inp[pulse_start*1000+1:dt:pulse_stop*1000] .= 1.0
return inp
end
function poisson_gen(inp_duration, pulse_start, pulse_stop,rate)
# Poisson input generator
prob = rate*1e-3
mask = rand(inp_duration*1000)
spikes = zeros(inp_duration*1000)
spikes[mask .< prob] .= 1.0
spikes[1:pulse_start*1000] .= 0 #I have modified this one
spikes[pulse_stop*1000:end] .=0
return spikes
end
function cos_gen(inp_duration, pulse_start, pulse_stop,rate)
spikes = zeros(inp_duration * 1000)
time = LinRange(0, inp_duration, inp_duration * 1000)
co = cos.(2 * π * rate .* time)
mask = 20 .* rand(inp_duration * 1000)
spikes[mask .< co] .= 1
spikes[1:pulse_start * 1000] .= 0
spikes[pulse_stop * 1000:end] .= 0
return spikes
end
end
And these are the plots I am currently getting:
Soma plot
AMPA and NMDA plot
GABA A and GABA B plot
Whereas these are the expected plots:
Membrane current expected behaviour
AMPA and NMDA expected plot
AMPA and NMDA expected frequency plots
Moreover, if you have some knowledge on Neuroscience and therefore think that any parameter or initial value does not make sense, I am open to any suggestion!
Many thanks in advanced, and I am eagerly looking forward to receiving your responses!
Kind regards.