def build_step(self, signals):
    """Add the ops for one neuron update step to the graph.

    Gathers the input current and every state signal, runs
    ``self.neuron_step_math`` on them through ``tf.py_func``, and scatters
    the function's outputs back into the output and state signals.

    Parameters
    ----------
    signals
        Object used here for ``gather``/``scatter`` and for its ``dt`` and
        ``minibatch_size`` attributes.
    """
    current = signals.gather(self.J_data)
    state_values = [signals.gather(sig) for sig in self.state_data]
    state_dtypes = [sig.dtype for sig in self.state_data]

    # The step function is not known to be thread safe, so each call is
    # made to wait on the result of the previous call before it may run.
    # Only the py_func op itself is created inside this scope.
    with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
        func_inputs = [signals.dt, current] + state_values
        func_dtypes = [self.output_data.dtype] + state_dtypes
        outputs = tf.py_func(
            self.neuron_step_math,
            func_inputs,
            func_dtypes,
            name=self.neuron_step_math.__name__,
        )

    # First output is the neuron output; the rest are the updated states,
    # in the same order as ``self.state_data``.
    neuron_out = outputs[0]
    new_states = outputs[1:]

    # Remember this call's result so the next call can depend on it.
    self.prev_result = [neuron_out]

    # py_func outputs carry no static shape, so set it explicitly to the
    # signal shape plus a trailing minibatch dimension before scattering.
    batch_dim = (signals.minibatch_size,)
    neuron_out.set_shape(self.output_data.shape + batch_dim)
    signals.scatter(self.output_data, neuron_out)

    for sig, new_state in zip(self.state_data, new_states):
        new_state.set_shape(sig.shape + (signals.minibatch_size,))
        signals.scatter(sig, new_state)
评论列表
文章目录