def __init__(self, values, mask, name='SequenceBatch'):
    """Store ``values`` and ``mask`` after asserting their shapes agree.

    A ``tf.assert_equal`` op compares the first two dimensions of ``values``
    with the full shape of ``mask``. The stored tensors are created under a
    control dependency on that assertion, so the check runs whenever they
    are used.

    Args:
        values: tensor of rank >= 2 (its first two dimensions are sliced
            out). Presumably laid out as (batch_size, seq_length, ...) —
            TODO confirm against callers.
        mask: tensor whose shape must equal the first two dimensions of
            ``values``.
        name: name of the ``tf.name_scope`` the ops are created under.
    """
    with tf.name_scope(name):
        # check that dimensions are correct
        values_shape = tf.shape(values)
        mask_shape = tf.shape(mask)
        values_shape_prefix = tf.slice(values_shape, [0], [2])

        # Number of entries printed from each shape vector if the assertion
        # fails. Static ranks may be unknown (None), and comparing None with
        # an int raises TypeError in Python 3. So ignore unknown ranks, and
        # use TF's default summarize when neither rank is known.
        known_ranks = [rank for rank in (values.get_shape().ndims,
                                         mask.get_shape().ndims)
                       if rank is not None]
        max_rank = max(known_ranks) if known_ranks else None

        assert_op = tf.assert_equal(values_shape_prefix, mask_shape,
                                    data=[values_shape_prefix, mask_shape],
                                    summarize=max_rank,
                                    name="assert_shape_prefix")
        # Route both stored tensors through identity ops that depend on the
        # assertion; otherwise nothing would force the assert op to execute.
        with tf.control_dependencies([assert_op]):
            self._values = tf.identity(values, name='values')
            self._mask = tf.identity(mask, name='mask')
评论列表
文章目录