def _add_blanks(y, blank_symbol, y_mask=None):
    """Interleave ``blank_symbol`` into a label matrix and update its mask.

    Each column ``[y_0, y_1, ..., y_{L-1}]`` of ``y`` becomes
    ``[blank, y_0, blank, y_1, blank, ..., y_{L-1}, blank]``, i.e. a blank
    before the first label and after every label (presumably the layout a
    CTC-style loss expects -- confirm against callers).

    Input shape:  output_seq_len x num_batch
    Output shape: 2*output_seq_len+1 x num_batch

    Parameters
    ----------
    y : symbolic tensor, output_seq_len x num_batch
        Label matrix (presumably integer label indices).
    blank_symbol : scalar
        Value inserted between labels and at the start/end of each column.
    y_mask : symbolic tensor, optional
        Mask for ``y`` (assumed to have the same shape). Each mask row is
        duplicated, so a label and the blank that follows it share the same
        validity, and a row of ones (dtype ``floatX``) is prepended for the
        leading blank, which is always valid.

    Returns
    -------
    blanked_y : int32 tensor, 2*output_seq_len+1 x num_batch
    blanked_y_mask : tensor of the same shape, or None if ``y_mask`` is None
    """
    def _interleave(first, second):
        # Merge two (L, B) matrices row-wise into (2L, B) with rows
        # first[0], second[0], first[1], second[1], ...
        # Working on the (B, L, 2) transposed stack makes a plain reshape
        # produce the alternating order along the time axis.
        stacked = tensor.concatenate([first.T.dimshuffle(0, 1, 'x'),
                                      second.T.dimshuffle(0, 1, 'x')],
                                     axis=2)
        return stacked.reshape((stacked.shape[0],
                                stacked.shape[1] * stacked.shape[2])).T

    # Labels with a blank after each one: y_0, b, y_1, b, ..., y_{L-1}, b.
    interleaved = _interleave(y, tensor.zeros_like(y) + blank_symbol)
    # Leading blank row. Use the label matrix's own dtype rather than the
    # floatX default of tensor.zeros, so integer labels are not upcast to
    # float by the concatenation. The explicit (1, num_batch) shape keeps
    # exactly one blank row even when output_seq_len == 0.
    leading_blank = tensor.zeros((1, interleaved.shape[1]),
                                 dtype=interleaved.dtype) + blank_symbol
    blanked_y = tensor.concatenate([leading_blank, interleaved], axis=0)

    if y_mask is None:
        blanked_y_mask = None
    else:
        # m_0, m_0, m_1, m_1, ...: each label and its trailing blank share
        # the label's mask value.
        doubled_mask = _interleave(y_mask, y_mask)
        leading_valid = tensor.ones((1, doubled_mask.shape[1]), dtype=floatX)
        blanked_y_mask = tensor.concatenate([leading_valid, doubled_mask],
                                            axis=0)
    return blanked_y.astype('int32'), blanked_y_mask
评论列表
文章目录