def __call__(self, x, split_into_variables=True, discard_context=False):
    """Run the conv -> recurrent (SRU) -> dense pipeline on a batch.

    Args:
        x: Input batch. Axis 0 is used as the batch axis and axis 3 as the
            time (sequence) axis; the conv output is flattened to
            ``(batchsize, -1, seq_length)`` before the recurrent layers.
            NOTE(review): the meaning of axes 1-2 of ``x`` is not visible
            here -- confirm against the data pipeline.
        split_into_variables: If true, return a tuple of ``seq_length``
            Variables, one per time step, each of shape ``(batchsize, -1)``.
            Otherwise return a single Variable with the time axis moved to
            position 1 and axis 2 squeezed out.
        discard_context: Unless true, the context returned by each SRU layer
            is stored in ``self.contexts`` so the next call starts from it;
            when true, the stored contexts are left untouched.

    Returns:
        A tuple of Variables when ``split_into_variables`` is true, otherwise
        a single Variable.
    """
    batchsize = x.shape[0]
    seq_length = x.shape[3]

    # Convolutional front end; collapse every axis except batch and time so
    # the recurrent layers receive (batch, -1, time).
    out_data = self.conv_blocks(x)
    out_data = functions.reshape(out_data, (batchsize, -1, seq_length))

    # Recurrent stack: each entry is (sru,) or (sru, dropout).
    for index, blocks in enumerate(self.rnn_blocks.blocks):
        sru = blocks[0]
        dropout = blocks[1] if len(blocks) == 2 else None
        hidden, _cell, context = sru(out_data, self.contexts[index])
        # Feed this layer's hidden output forward; without this the SRU
        # stack would not influence the result at all.
        out_data = hidden
        if not discard_context:
            self.contexts[index] = context
        if dropout is not None:
            out_data = dropout(out_data)

    # Dense head; it must preserve the time axis (axis 2).
    out_data = self.dense_blocks(out_data)
    assert out_data.shape[2] == seq_length, (
        "dense_blocks changed the time axis: expected length %d, got %d"
        % (seq_length, out_data.shape[2]))

    # Move time to axis 1: (batch, time, ...).
    out_data = F.swapaxes(out_data, 1, 2)
    if split_into_variables:
        # Presumably for a CTC loss, which takes one Variable per time step:
        # flatten each step's values, then split along time into
        # seq_length equal pieces of shape (batchsize, -1).
        out_data = F.reshape(out_data, (batchsize, -1))
        return F.split_axis(out_data, seq_length, axis=1)
    # NOTE(review): squeezing axis 2 only succeeds when that axis has size 1;
    # confirm the intended shape of dense_blocks' output for this path.
    return F.squeeze(out_data, axis=2)
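# NOTE: the lines "评论列表" ("Comment list") and "文章目录" ("Table of contents")
# that appeared here were web-page navigation residue from the scraped source,
# not code; kept as a comment so the file stays valid Python.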