def __init__(self, align='left', seq_length=None, dtype=tf.int32, name='FeedSequenceBatch'):
    """Build a SequenceBatch whose values and mask are supplied through placeholders.

    Args:
        align (str): either 'left' or 'right'. With 'left', values sit at the start of
            each row and padding follows them. With 'right', padding comes first and
            values sit at the end of each row. Defaults to 'left'.
        seq_length (int): if given, the static column count of both the values and
            mask placeholders is fixed to this number. If None (the default), the
            column count is left dynamic.
        dtype: dtype of the values placeholder. Defaults to tf.int32. The mask
            placeholder is always tf.float32.
        name (str): name scope under which the placeholders are created.

    Raises:
        ValueError: if `align` is neither 'left' nor 'right'.
    """
    allowed_alignments = ('left', 'right')
    if align not in allowed_alignments:
        raise ValueError("align must be either 'left' or 'right'.")

    self._align_right = align == 'right'
    self._seq_length = seq_length

    with tf.name_scope(name):
        # Both placeholders are declared 2-D: (batch_size, seq_length).
        values = tf.placeholder(dtype, shape=[None, None], name='values')
        mask = tf.placeholder(tf.float32, shape=[None, None], name='mask')

        if seq_length is not None:
            # Fix the column dimension statically; the batch dimension is carried over
            # unchanged from the placeholder's declared shape.
            batch_dim, _ = values.get_shape()
            static_shape = tf.TensorShape([batch_dim, tf.Dimension(seq_length)])
            for placeholder in (values, mask):
                placeholder.set_shape(static_shape)

    super(FeedSequenceBatch, self).__init__(values, mask)
# Comment list
# Table of contents