def build_step(self, signals):
    """Advance the LIF neuron state by one timestep in the TF graph.

    Reads the input current ``J``, the membrane voltage and the remaining
    refractory time from ``signals``. It then integrates the voltage, emits
    spikes wherever the voltage exceeds ``self.one``, and writes back the
    spikes, the updated refractory times and the reset/floored voltages.

    Parameters
    ----------
    signals
        Signal container providing ``gather``, ``scatter`` and
        ``mark_gather``, the timestep ``dt``, and the ``dtype`` used for
        the emitted spike tensor.
    """
    # Read state. The order of these gathers matches the original build.
    current = signals.gather(self.J_data)
    v = signals.gather(self.voltage_data)
    ref_time = signals.gather(self.refractory_data)

    dt = signals.dt

    # Count the refractory period down by one step. Then integrate only over
    # the portion of this step that lies outside it, clipped to
    # [self.zero, dt].
    ref_time -= dt
    active_dt = tf.clip_by_value(dt - ref_time, self.zero, dt)

    # Exponential relaxation of v toward the input current over active_dt:
    # v_new = v + (J - v) * (1 - exp(-active_dt / tau_rc)).
    decay = tf.expm1(-active_dt / self.tau_rc)
    v -= (current - v) * decay

    # Threshold crossing -> spike of height self.amplitude, else 0.
    fired = v > self.one
    signals.scatter(self.output_data,
                    tf.cast(fired, signals.dtype) * self.amplitude)

    # Inverting the same exponential: tau_rc * log((J - v) / (J - 1)) is minus
    # the time needed to climb from 1 to v. Spiking neurons therefore get
    # tau_ref minus their overshoot past threshold. The extra dt offsets the
    # decrement applied at the start of the next call.
    overshoot = self.tau_rc * tf.log1p((self.one - v) /
                                       (current - self.one))
    spike_ref = self.tau_ref + dt + overshoot
    ref_time = tf.where(fired, spike_ref, ref_time)

    # NOTE(review): J was already gathered above. Presumably this records an
    # additional read of J_data for the signal bookkeeping before the writes
    # below. Confirm against the signals implementation.
    signals.mark_gather(self.J_data)
    signals.scatter(self.refractory_data, ref_time)

    # Spiking neurons reset to self.zeros; the rest are floored at
    # self.min_voltage.
    v = tf.where(fired, self.zeros, tf.maximum(v, self.min_voltage))
    signals.scatter(self.voltage_data, v)
评论列表
文章目录