def run_trial(self, trial_input, t_connectivity=None, use_input=True):
    """Step the network through a single trial and record every step.

    ``trial_input`` is split along axis 0 into ``trial_input.shape[0]``
    slices (each keeps a leading axis of length 1), and ``self.rnn_step``
    is called once per slice, threading the state it returns into the
    next call. The starting state is the first row of ``self.init_state``
    with a leading axis of length 1 added.

    Args:
        trial_input: Array whose first axis indexes the steps of the trial.
            (Presumably ``(n_steps, n_inputs)`` -- confirm against callers.)
        t_connectivity: Optional sequence or array indexed by step; element
            ``i`` is forwarded to ``rnn_step`` as its third argument at
            step ``i``. When ``None`` or empty, an all-ones array shaped
            like ``self.W_rec`` is used at every step instead.
        use_input: Forwarded unchanged to ``rnn_step``.

    Returns:
        Tuple ``(outputs, states)``: the per-step outputs and per-step
        states returned by ``rnn_step``, each stacked with ``np.array``.

    Raises:
        IndexError: If a non-empty ``t_connectivity`` has fewer entries
            than ``trial_input`` has steps.
    """
    rnn_inputs = np.split(trial_input, trial_input.shape[0], axis=0)
    state = np.expand_dims(self.init_state[0, :], 0)

    # Do not use bare truthiness here: `if t_connectivity:` raises
    # ValueError when the masks arrive as a NumPy array. An empty
    # sequence still falls back to the all-ones mask, as it always did.
    use_masks = t_connectivity is not None and len(t_connectivity) > 0
    # Loop-invariant, so build it once. Assumes rnn_step does not modify
    # the mask in place -- TODO confirm.
    default_mask = None if use_masks else np.ones_like(self.W_rec)

    rnn_outputs = []
    rnn_states = []
    for i, rnn_input in enumerate(rnn_inputs):
        mask = t_connectivity[i] if use_masks else default_mask
        output, state = self.rnn_step(state, rnn_input, mask, use_input)
        rnn_outputs.append(output)
        rnn_states.append(state)
    return np.array(rnn_outputs), np.array(rnn_states)
# apply the RNN to a whole batch of inputs
评论列表
文章目录