def forward(self, x):
    """Run the QRNN forward pass over every time step of ``x``.

    On the first call (while ``self.c`` is None) the context cell is
    initialised to a float32 zero tensor of shape
    ``[<size of x's axis 0>, self.kernel.size]``. Then, for each time step,
    gate tensors ``f``, ``z``, ``o`` are obtained from ``self.kernel`` and
    passed to ``self._step``.

    Two code paths are used:

    * ``self.conv_size <= 2``: ``x`` is transposed to time-major layout and
      ``self.kernel.forward`` is applied to each step's slice.
    * otherwise: ``self.kernel.conv(x)`` produces the gate tensors for all
      steps at once, and they are consumed slice by slice along axis 0.

    Args:
        x: Input tensor, laid out as
            ``batch_size x sentence_length x word_length``. The dimensions
            measured below must be statically known, because they are
            converted with ``int()`` to drive Python loops.

    Returns:
        ``self.h`` after the final ``_step`` call. NOTE(review): ``self.h``
        is presumably the hidden state maintained by ``_step``; that method
        is not visible here, so confirm.
    """

    def static_len(tensor):
        # Static size of the tensor's leading axis. The Python ``range`` loops
        # below need a concrete integer, so this fails if the dimension is not
        # known at graph-construction time.
        return int(tensor.get_shape()[0])

    with tf.variable_scope("QRNN/Forward"):
        if self.c is None:
            # Lazily create the context cell: one row per entry of x's leading
            # (batch) axis, kernel.size columns.
            self.c = tf.zeros([static_len(x), self.kernel.size], dtype=tf.float32)

        if self.conv_size <= 2:
            # x is batch_size x sentence_length x word_length; transpose it to
            # sentence_length x batch_size x word_length so that indexing the
            # leading axis yields one time step.
            time_major = tf.transpose(x, [1, 0, 2])
            # Step by index: under TF1 graph execution a Tensor cannot be
            # iterated directly.
            for i in range(static_len(time_major)):
                step_input = time_major[i]  # batch_size x word_length
                f, z, o = self.kernel.forward(step_input)
                self._step(f, z, o)
        else:
            # NOTE(review): assumes kernel.conv returns gate tensors whose
            # axis 0 is the time axis; confirm against the kernel.
            c_f, c_z, c_o = self.kernel.conv(x)
            for i in range(static_len(c_f)):
                self._step(c_f[i], c_z[i], c_o[i])

    return self.h
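# (web-page residue from the original paste) 评论列表 -- "comment list"
# (web-page residue from the original paste) 文章目录 -- "table of contents"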