def forward(self, data, stable_version=False):
    """Propagate ``data`` through every layer and return the network output.

    Input has each row as a data vector; the output does as well.

    For each layer ``(bias, weight, pre_w, post_w)``:

    1. ``pre_w[0]`` and ``post_w[0]`` are zero-padded by 2 on every side,
       mapped row by row through ``process_pre_post_w``, and cropped back to
       the ``size`` interior rows, where ``size = pre_w[0].shape[0]``.
    2. Every sample row is viewed as a ``size x size`` matrix ``X`` and
       replaced by ``pre . X . post_T.T``, then flattened back. Because of
       this reshape, each layer's input width must equal ``size ** 2``.
    3. An affine map ``data . weight + bias`` follows. Hidden layers use
       ReLU. The final layer uses softmax, or ``log_softmax`` when
       ``stable_version`` is true.

    Args:
        data: symbolic matrix with one sample per row.
        stable_version: if true, the final layer emits ``log_softmax``
            instead of ``softmax``.

    Returns:
        Symbolic expression for the final layer's output, one row per sample.
    """

    def _padded_row_transform(w, size):
        """Pad ``w`` with a 2-wide zero border, scan its rows, crop to ``size`` rows."""
        # NOTE(review): assumes w is a square (size, size) matrix -- confirm.
        zeros = T.zeros((size + 4, size + 4))
        padded = T.set_subtensor(zeros[2:size + 2, 2:size + 2], w)
        # scan walks both sequences row by row, so process_pre_post_w (defined
        # elsewhere) receives (padded_row, zero_row). Its per-row output must
        # have length `size` for the later dot with a (size, size) sample to be
        # well-formed. Why a zero row is passed is not visible here -- confirm.
        rows, _ = scan(process_pre_post_w, sequences=[padded, zeros])
        # Drop the two padding rows at the top and bottom; keep all columns.
        return rows[2:size + 2, :]

    layers = zip(self.biases, self.weights, self.pre_w, self.post_w)
    for layer_idx, (bias, weight, pre_w, post_w) in enumerate(layers, start=1):
        size = pre_w[0].shape[0]
        pre = _padded_row_transform(pre_w[0], size)
        post_T = _padded_row_transform(post_w[0], size)

        # View each sample row as a (size x size) matrix X and apply the
        # bilinear transform pre . X . post_T.T, then flatten back.
        ori_shape = data.shape
        matrices = T.reshape(data, (ori_shape[0], size, size))
        product, _ = scan(lambda x, A, B: T.dot(T.dot(A, x), B),
                          sequences=matrices, non_sequences=[pre, post_T.T])
        data = T.reshape(product, ori_shape)

        z = T.dot(data, weight) + bias
        # layer_idx runs 1..len(self.biases). Softmax is reached only on the
        # last iteration when self.num_layers == len(self.biases) + 1 (input
        # layer counted) -- presumably the convention used; confirm.
        if layer_idx < self.num_layers - 1:
            data = T.nnet.relu(z)
        elif not stable_version:
            data = T.nnet.softmax(z)
        else:
            data = log_softmax(z)
    return data
评论列表
文章目录