def run_trials(self, trial_input, batch_size, t_connectivity=None, use_input=True):
    """Step the network through every time slice of ``trial_input``.

    ``trial_input`` is split along axis 1 into ``trial_input.shape[1]``
    slices (each slice keeps axis 1 with length 1), and ``self.rnn_step`` is
    called once per slice, threading the returned state into the next call.

    Args:
        trial_input: Array whose axis 1 indexes time steps.
            NOTE(review): presumably shaped (batch, time, n_in) -- confirm
            against ``rnn_step``.
        batch_size: Number of copies of the initial state to simulate;
            the first row of ``self.init_state`` is repeated this many times
            along axis 0.
        t_connectivity: Optional sequence (or array) indexed by time step;
            element ``t`` is passed to ``rnn_step`` as the connectivity mask
            for step ``t``. When ``None`` or empty, an all-ones mask shaped
            like ``self.W_rec`` is used for every step.
        use_input: Forwarded unchanged to ``self.rnn_step``.

    Returns:
        Tuple ``(outputs, states)`` of arrays built from the per-step
        outputs and states, with the time step as the leading axis.
    """
    rnn_inputs = np.split(trial_input, trial_input.shape[1], axis=1)

    # Every trial in the batch starts from the same initial state row.
    state = np.expand_dims(self.init_state[0, :], 0)
    state = np.repeat(state, batch_size, 0)

    # Explicit None/length check: plain truthiness raises ValueError for
    # multi-element NumPy arrays. An empty sequence keeps the old fallback.
    use_mask = t_connectivity is not None and len(t_connectivity) > 0
    # The default mask is loop-invariant, so build it once.
    full_connectivity = None if use_mask else np.ones_like(self.W_rec)

    rnn_outputs = []
    rnn_states = []
    for step, rnn_input in enumerate(rnn_inputs):
        mask = t_connectivity[step] if use_mask else full_connectivity
        output, state = self.rnn_step(state, rnn_input, mask, use_input)
        rnn_outputs.append(output)
        rnn_states.append(state)

    return np.array(rnn_outputs), np.array(rnn_states)
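# (web-page residue, not code) 评论列表 -- "comment list"
# (web-page residue, not code) 文章目录 -- "table of contents"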