def output_func(self, input):
    """Concatenate a query, its operands, and bilinear interaction scores.

    ``input[0]`` is used as the query ``q``. Entries ``input[1]`` through
    ``input[self.position]`` are the operands. The returned tensor joins
    the following pieces along axis 1, in this order:

    1. ``q`` itself.
    2. For each ``i`` in ``range(self.position)``:

       * the score ``T.batched_dot(q, T.dot(input[i + 1], self.W[i].T))``
         as one column;
       * followed by ``input[i + 1]``.
    3. For each pair ``j < i`` with ``i`` in ``range(1, self.position)``:

       * the score
         ``T.batched_dot(input[j + 1], T.dot(input[i + 1], self.W[self.position].T))``
         as one column.

    :param input: indexable sequence of symbolic Theano tensors. Entries
        ``0 .. self.position`` are read. Presumably each is a 2-D
        ``(batch, dim)`` matrix — TODO confirm against the caller.
    :returns: the symbolic result of ``T.concatenate(pieces, axis=1)``.
    """
    q = input[0]
    pieces = [q]

    # Query-vs-operand scores: operand i has its own weight matrix W[i].
    # ``range`` instead of the Python-2-only ``xrange`` so this also runs on
    # Python 3 (and matches the ``range`` used below).
    for i in range(self.position):
        operand = input[i + 1]
        score = T.batched_dot(q, T.dot(operand, self.W[i].T))
        # The score is rank 1 (dimshuffle(0, 'x') requires that). Turn it
        # into a column so it can be concatenated along axis 1.
        pieces.append(score.dimshuffle(0, 'x'))
        pieces.append(operand)

    # Operand-vs-operand scores for every pair j < i.
    # NOTE(review): every pair shares the single matrix self.W[self.position].
    # A per-pair ``begin_index += 1`` existed but was commented out in the
    # original. Confirm whether per-pair weights were intended before
    # changing this.
    for i in range(1, self.position):
        # Depends only on i, so compute it once per i rather than once per
        # (i, j). It stays inside this loop so W[self.position] is only
        # indexed when a pair actually exists, exactly as before.
        projected = T.dot(input[i + 1], self.W[self.position].T)
        for j in range(i):
            score = T.batched_dot(input[j + 1], projected)
            pieces.append(score.dimshuffle(0, 'x'))

    return T.concatenate(pieces, axis=1)
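# (Scraped web-page residue, kept as comments: "Comment list")
# (Scraped web-page residue, kept as comments: "Table of contents")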